diff --git a/bin/dev/release_cleanup.sh b/bin/dev/release_cleanup.sh new file mode 100755 index 00000000..abb07b6a --- /dev/null +++ b/bin/dev/release_cleanup.sh @@ -0,0 +1,40 @@ +#!/bin/sh + +# Copyright (C) 2012 The Regents of The University California. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DEVDIR="`dirname $0`" +BINDIR="`dirname $DEVDIR`" +FWDIR="`dirname $BINDIR`" + +rm -rf $FWDIR/run-tests-from-scratch-workspace +rm -rf $FWDIR/test_warehouses + +rm -rf $FWDIR/conf/shark-env.sh + +rm -rf $FWDIR/metastore_db +rm -rf $FWDIR/derby.log + +rm -rf $FWDIR/project/target $FWDIR/project/project/target + +rm -rf $FWDIR/target/resolution-cache +rm -rf $FWDIR/target/streams +rm -rf $FWDIR/target/scala-*/cache +rm -rf $FWDIR/target/scala-*/classes +rm -rf $FWDIR/target/scala-*/test-classes + +find $FWDIR -name ".DS_Store" -exec rm {} \; +find $FWDIR -name ".history" -exec rm {} \; + diff --git a/bin/dev/run-tests-from-scratch b/bin/dev/run-tests-from-scratch index 1ed30b3b..085e38bb 100755 --- a/bin/dev/run-tests-from-scratch +++ b/bin/dev/run-tests-from-scratch @@ -12,10 +12,11 @@ # Set up config vars using env vars or defaults; parse cmd line flags. ##################################################################### SHARK_PROJ_DIR_DEFAULT="$(cd `dirname $0`/../../; pwd)" +SBT_OPTS_DEFAULT="-Xms512M -Xmx2048M -Xss1M -XX:+CMSClassUnloadingEnabled -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=256m -XX:+UseCodeCacheFlushing" SPARK_MEM_DEFAULT=4g SHARK_MASTER_MEM_DEFAULT=4g SPARK_KV_JAVA_OPTS_DEFAULT=("-Dspark.local.dir=/tmp " "-Dspark.kryoserializer.buffer.mb=10 ") -SPARK_GIT_URL_DEFAULT="https://github.com/mesos/spark.git" +SPARK_GIT_URL_DEFAULT="https://github.com/apache/incubator-spark.git spark" HIVE_GIT_URL_DEFAULT="https://github.com/amplab/hive.git -b shark-0.9" SPARK_HADOOP_VERSION_DEFAULT="1.0.4" SPARK_WITH_YARN_DEFAULT=false @@ -49,6 +50,10 @@ else fi fi +if [ "x$SBT_OPTS" == "x" ] ; then + SBT_OPTS=$SBT_OPTS_DEFAULT +fi + if [ "x$SPARK_MEM" == "x" ] ; then export SPARK_MEM=$SPARK_MEM_DEFAULT fi @@ -117,6 +122,7 @@ Required Options: Optional configuration environment variables: SHARK_PROJ_DIR (default: "$SHARK_PROJ_DIR_DEFAULT") SCALA_HOME (default: Scala version ${SCALA_VERSION} will be downloaded and used) + SBT_OPTS (default: "$SBT_OPTS_DEFAULT") SPARK_MEM (default: $SPARK_MEM_DEFAULT) SHARK_MASTER_MEM (default: $SHARK_MASTER_MEM_DEFAULT) SPARK_JAVA_OPTS (default: "${SPARK_KV_JAVA_OPTS_DEFAULT[@]}") @@ -226,6 +232,7 @@ fi # Download Scala if SCALA_HOME is not specified. #################################################################### if [ "x$SCALA_HOME" == "x" ] ; then + rm -rf ./scala*tgz wget $SCALA_DOWNLOAD_PATH tar xvfz scala*tgz export SCALA_HOME="$WORKSPACE/scala-$SCALA_VERSION" @@ -251,7 +258,8 @@ else export SPARK_HADOOP_VERSION=$SPARK_HADOOP_VERSION export SPARK_WITH_YARN=$SPARK_WITH_YARN # Build spark and push the jars to local Ivy/Maven caches. 
- sbt/sbt clean publish-local + wget -nc http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/0.13.0/sbt-launch.jar + java $SBT_OPTS -jar sbt-launch.jar clean publish-local popd fi export SPARK_HOME="$WORKSPACE/spark" @@ -274,17 +282,11 @@ export HADOOP_HOME="$WORKSPACE/hadoop-${SPARK_HADOOP_VERSION}" # Download and build Hive. ##################################################################### if $SKIP_HIVE ; then - if [ ! -e "hive" -o ! -e "hive-warehouse" ] ; then - echo "hive and hive-warehouse dirs must exist when skipping Hive download and build stage." + if [ ! -e "hive" ] ; then + echo "hive dir must exist when skipping Hive download and build stage." exit -1 fi else - # Setup the Hive warehouse directory. - HIVE_WAREHOUSE=./hive-warehouse - rm -rf $HIVE_WAREHOUSE - mkdir -p $HIVE_WAREHOUSE - chmod 0777 $HIVE_WAREHOUSE - rm -rf hive git clone $HIVE_GIT_URL pushd hive diff --git a/bin/ext/sharkserver.sh b/bin/ext/sharkserver.sh index e93aadee..de4c08a8 100644 --- a/bin/ext/sharkserver.sh +++ b/bin/ext/sharkserver.sh @@ -18,10 +18,6 @@ THISSERVICE=sharkserver export SERVICE_LIST="${SERVICE_LIST}${THISSERVICE} " -# Use Java to launch Shark otherwise the unit tests cannot properly kill -# the server process. -export SHARK_LAUNCH_WITH_JAVA=1 - sharkserver() { echo "Starting the Shark Server" exec $FWDIR/run shark.SharkServer "$@" diff --git a/conf/blinkdb-env.sh.template b/conf/blinkdb-env.sh.template index a7a42cc9..deb14f2b 100755 --- a/conf/blinkdb-env.sh.template +++ b/conf/blinkdb-env.sh.template @@ -39,6 +39,11 @@ export HIVE_HOME="" # Only required if using Mesos: #export MESOS_NATIVE_LIBRARY=/usr/local/lib/libmesos.so +# Only required if run shark with spark on yarn +#export SHARK_EXEC_MODE=yarn +#export SPARK_ASSEMBLY_JAR= +#export SHARK_ASSEMBLY_JAR= + # (Optional) Extra classpath #export SPARK_LIBRARY_PATH="" diff --git a/project/SharkBuild.scala b/project/SharkBuild.scala index f7903d5d..6f2ba46d 100755 --- a/project/SharkBuild.scala +++ b/project/SharkBuild.scala @@ -21,23 +21,32 @@ import Keys._ import sbtassembly.Plugin._ import AssemblyKeys._ +import scala.util.Properties.{ envOrNone => env } + object SharkBuild extends Build { val BLINKDB_VERSION = "0.1.0-SNAPSHOT" // Shark version - val SHARK_VERSION = "0.8.0-SNAPSHOT" + val SHARK_VERSION = "0.9.0-hive0.9-SNAPSHOT" - val SPARK_VERSION = "0.8.0-SNAPSHOT" + val SPARK_VERSION = "0.9.0-incubating" - val SCALA_VERSION = "2.9.3" + val SCALA_VERSION = "2.10.3" // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or // "1.0.1" for Apache releases, or "0.20.2-cdh3u3" for Cloudera Hadoop. - val HADOOP_VERSION = "1.0.4" + val DEFAULT_HADOOP_VERSION = "1.0.4" + + lazy val hadoopVersion = env("SHARK_HADOOP_VERSION") orElse + env("SPARK_HADOOP_VERSION") getOrElse + DEFAULT_HADOOP_VERSION + + // Whether to build Shark with Yarn support + val YARN_ENABLED = env("SHARK_YARN").getOrElse("false").toBoolean // Whether to build Shark with Tachyon jar. 
- val TACHYON_ENABLED = false + val TACHYON_ENABLED = true lazy val root = Project( id = "root", @@ -47,6 +56,10 @@ object SharkBuild extends Build { val excludeKyro = ExclusionRule(organization = "de.javakaffee") val excludeHadoop = ExclusionRule(organization = "org.apache.hadoop") val excludeNetty = ExclusionRule(organization = "org.jboss.netty") + val excludeCurator = ExclusionRule(organization = "org.apache.curator") + val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson") + val excludeAsm = ExclusionRule(organization = "asm") + val excludeSnappy = ExclusionRule(organization = "org.xerial.snappy") def coreSettings = Defaults.defaultSettings ++ Seq( @@ -54,15 +67,13 @@ object SharkBuild extends Build { organization := "edu.berkeley.cs.amplab", version := SHARK_VERSION, scalaVersion := SCALA_VERSION, - scalacOptions := Seq("-deprecation", "-unchecked", "-optimize"), + scalacOptions := Seq("-deprecation", "-unchecked", "-optimize", "-feature", "-Yinline-warnings"), parallelExecution in Test := false, // Download managed jars into lib_managed. retrieveManaged := true, resolvers ++= Seq( "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/", - "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/", - "Spray Repository" at "http://repo.spray.cc/", "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/", "Local Maven" at Path.userHome.asFile.toURI.toURL + ".m2/repository" ), @@ -70,6 +81,9 @@ object SharkBuild extends Build { fork := true, javaOptions += "-XX:MaxPermSize=512m", javaOptions += "-Xmx2g", + javaOptions += "-Dsun.io.serialization.extendedDebugInfo=true", + + testOptions in Test += Tests.Argument("-oF"), // Full stack trace on test failures testListeners <<= target.map( t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), @@ -102,7 +116,7 @@ object SharkBuild extends Build { "org.apache.spark" %% "spark-core" % SPARK_VERSION, "org.apache.spark" %% "spark-repl" % SPARK_VERSION, "com.google.guava" % "guava" % "14.0.1", - "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeNetty), + "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm) force(), // See https://code.google.com/p/guava-libraries/issues/detail?id=1095 "com.google.code.findbugs" % "jsr305" % "1.3.+", @@ -114,14 +128,14 @@ object SharkBuild extends Build { // Test infrastructure "org.scalatest" %% "scalatest" % "1.9.1" % "test", "junit" % "junit" % "4.10" % "test", - "net.java.dev.jets3t" % "jets3t" % "0.9.0", + "net.java.dev.jets3t" % "jets3t" % "0.7.1", "com.novocode" % "junit-interface" % "0.8" % "test") ++ - (if (TACHYON_ENABLED) Some("org.tachyonproject" % "tachyon" % "0.3.0-SNAPSHOT" excludeAll(excludeKyro, excludeHadoop) ) else None).toSeq - ) + (if (YARN_ENABLED) Some("org.apache.spark" %% "spark-yarn" % SPARK_VERSION) else None).toSeq ++ + (if (TACHYON_ENABLED) Some("org.tachyonproject" % "tachyon" % "0.3.0" excludeAll(excludeKyro, excludeHadoop, excludeCurator, excludeJackson, excludeNetty, excludeAsm)) else None).toSeq + ) ++ org.scalastyle.sbt.ScalastylePlugin.Settings def assemblyProjSettings = Seq( - name := "shark-assembly", - jarName in assembly <<= version map { v => "shark-assembly-" + v + "-hadoop" + HADOOP_VERSION + ".jar" } + jarName in assembly <<= version map { v => "shark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" } ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq( 
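The SharkBuild.scala changes above stop hard-coding the Hadoop version and instead resolve it from the environment at build time. A minimal, self-contained Scala sketch of the same resolution order (the object wrapper and main method are only for illustration; the variable names mirror the patch):

    import scala.util.Properties.{envOrNone => env}

    object HadoopVersionProbe {
      val DEFAULT_HADOOP_VERSION = "1.0.4"

      // SHARK_HADOOP_VERSION takes precedence, then SPARK_HADOOP_VERSION, then the default.
      lazy val hadoopVersion: String =
        env("SHARK_HADOOP_VERSION") orElse env("SPARK_HADOOP_VERSION") getOrElse DEFAULT_HADOOP_VERSION

      def main(args: Array[String]): Unit = println(hadoopVersion)
    }

Exporting SHARK_HADOOP_VERSION (or SPARK_HADOOP_VERSION) before running sbt therefore selects both the hadoop-client dependency version and the Hadoop suffix of the assembly jar name.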
@@ -129,6 +143,7 @@ object SharkBuild extends Build { mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard + case "META-INF/services/org.apache.hadoop.fs.FileSystem" => MergeStrategy.concat case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } diff --git a/project/plugins.sbt b/project/plugins.sbt index d06b220d..5b2a7785 100755 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -15,6 +15,8 @@ addSbtPlugin("org.ensime" % "ensime-sbt-cmd" % "0.1.1") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.3.2") + addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.4.0") @@ -24,3 +26,5 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.9.2") resolvers += Resolver.url( "sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) + +resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" diff --git a/run b/run index 2a038411..a7688082 100755 --- a/run +++ b/run @@ -1,9 +1,9 @@ #!/bin/bash # This file is used to launch Shark on the master. -export SCALA_VERSION=2.9.3 -SHARK_VERSION=0.8.0-SNAPSHOT -BLINKDB_VERSION=0.1.0-SNAPSHOT +export SCALA_VERSION=2.10 +SHARK_VERSION=0.9.0-SNAPSHOT +BLINKDB_VERSION=0.2.0-SNAPSHOT # Figure out where the framework is installed FWDIR="$(cd `dirname $0`; pwd)" @@ -48,6 +48,26 @@ if [ -n "$MASTER" ] ; then fi fi +# check for shark with spark on yarn params +if [ "x$SHARK_EXEC_MODE" == "xyarn" ] ; then + if [ "x$SPARK_ASSEMBLY_JAR" == "x" ] ; then + echo "No SPARK_ASSEMBLY_JAR specified. Please set SPARK_ASSEMBLY_JAR for spark on yarn mode." + exit 1 + else + export SPARK_JAR=$SPARK_ASSEMBLY_JAR + fi + + if [ "x$SHARK_ASSEMBLY_JAR" == "x" ] ; then + echo "No SHARK_ASSEMBLY_JAR specified. Please set SHARK_ASSEMBLY_JAR for spark on yarn mode." + exit 1 + else + export SPARK_YARN_APP_JAR=$SHARK_ASSEMBLY_JAR + fi + + # use yarn-client mode for interactive shell. + export MASTER=yarn-client +fi + # Check for optionally specified configuration file path if [ "x$HIVE_CONF_DIR" == "x" ] ; then HIVE_CONF_DIR="$HIVE_HOME/conf" @@ -110,9 +130,10 @@ SPARK_CLASSPATH+=":$SHARK_HOME/target/scala-$SCALA_VERSION/test-classes" if [ "x$HADOOP_HOME" == "x" ] ; then - echo "No HADOOP_HOME specified. Shark will run in local-mode" + echo "No HADOOP_HOME specified. 
Shark will run in local-mode" else - SPARK_CLASSPATH+=:$HADOOP_HOME/conf + SPARK_CLASSPATH+=:$HADOOP_HOME/etc/hadoop + SPARK_CLASSPATH+=:$HADOOP_HOME/conf fi @@ -141,22 +162,16 @@ export JAVA_OPTS export ANT_OPTS=$JAVA_OPTS if [ "x$RUNNER" == "x" ] ; then - if [ "$SHARK_LAUNCH_WITH_JAVA" == "1" ]; then - CLASSPATH+=":$SCALA_HOME/lib/scala-library.jar" - CLASSPATH+=":$SCALA_HOME/lib/scala-compiler.jar" - CLASSPATH+=":$SCALA_HOME/lib/jline.jar" - if [ -n "$JAVA_HOME" ]; then - RUNNER="${JAVA_HOME}/bin/java" - else - RUNNER=java - fi - # The JVM doesn't read JAVA_OPTS by default so we need to pass it in - EXTRA_ARGS="$JAVA_OPTS" + CLASSPATH+=":$SCALA_HOME/lib/scala-library.jar" + CLASSPATH+=":$SCALA_HOME/lib/scala-compiler.jar" + CLASSPATH+=":$SCALA_HOME/lib/jline.jar" + if [ -n "$JAVA_HOME" ]; then + RUNNER="${JAVA_HOME}/bin/java" else - SCALA=${SCALA_HOME}/bin/scala - RUNNER="$SCALA -cp \"$CLASSPATH\"" - EXTRA_ARGS="" + RUNNER=java fi + # The JVM doesn't read JAVA_OPTS by default so we need to pass it in + EXTRA_ARGS="$JAVA_OPTS" fi exec $RUNNER $EXTRA_ARGS "$@" diff --git a/sbt/sbt b/sbt/sbt index 29657c61..78f994f6 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -5,9 +5,9 @@ if [ -e $BLINKDB_CONF_DIR/blinkdb-env.sh ] ; then . $BLINKDB_CONF_DIR/blinkdb-env.sh fi -if [[ "$@" == *"test"* ]]; then - if [ "x$HIVE_DEV_HOME" == "x" ]; then - echo "No HIVE_DEV_HOME specified. Required for tests. Please set HIVE_DEV_HOME." +if [[ "$@" == *"test"* ]] || [[ "$@" == "eclipse" ]]; then + if [[ "x$HIVE_DEV_HOME" == "x" ]]; then + echo "No HIVE_DEV_HOME specified. Required for tests and eclipse. Please set HIVE_DEV_HOME." exit 1 fi fi diff --git a/scalastyle-config.xml b/scalastyle-config.xml new file mode 100644 index 00000000..d3b75788 --- /dev/null +++ b/scalastyle-config.xml @@ -0,0 +1,125 @@ +<!-- Scalastyle standard configuration (check rule definitions omitted) --> diff --git a/src/main/java/shark/execution/ExplainTaskHelper.java b/src/main/java/shark/execution/ExplainTaskHelper.java index b85bce50..933ebb62 100755 --- a/src/main/java/shark/execution/ExplainTaskHelper.java +++ b/src/main/java/shark/execution/ExplainTaskHelper.java @@ -58,11 +58,11 @@ public static void outputPlan(Serializable work, PrintStream out, boolean extend // conf and then // the children if (work instanceof shark.execution.Operator) { - shark.execution.Operator> operator = - (shark.execution.Operator>) work; + shark.execution.Operator operator = + (shark.execution.Operator) work; out.println(indentString(indent) + "**" + operator.getClass().getName()); - if (operator.hiveOp().getConf() != null) { - outputPlan(operator.hiveOp().getConf(), out, extended, indent); + if (operator.desc() != null) { + outputPlan(operator.desc(), out, extended, indent); } if (operator.parentOperators() != null) { for (shark.execution.Operator op : operator.parentOperatorsAsJavaList()) { diff --git a/src/main/java/shark/execution/ReduceSinkOperatorHelper.java b/src/main/java/shark/execution/ReduceSinkOperatorHelper.java deleted file mode 100755 index e7f915fd..00000000 --- a/src/main/java/shark/execution/ReduceSinkOperatorHelper.java +++ /dev/null @@ -1,31 +0,0 @@ -package shark.execution; - -import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator; -import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; 
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.ql.metadata.HiveException; - -import java.util.List; - - -/** - * This class is used because we cannot call protected static methods from - * Scala. - */ -@SuppressWarnings("serial") -public class ReduceSinkOperatorHelper extends ReduceSinkOperator { - - public static StructObjectInspector initEvaluatorsAndReturnStruct( - ExprNodeEvaluator[] evals, List> distinctColIndices, - List outputColNames, int length, ObjectInspector rowInspector) - throws HiveException { - - return ReduceSinkOperator.initEvaluatorsAndReturnStruct( - evals, - distinctColIndices, - outputColNames, - length, - rowInspector); - } -} diff --git a/src/main/java/shark/tgf/Schema.java b/src/main/java/shark/tgf/Schema.java new file mode 100644 index 00000000..c571a15f --- /dev/null +++ b/src/main/java/shark/tgf/Schema.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2013 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.tgf; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.ElementType; +import java.lang.annotation.Target; + + +/** + * Schema annotation for TGFs, example syntax: @Schema(spec = "name string, age int") + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface Schema { + String spec(); +} diff --git a/src/main/resources/tablerdd/SharkContext_sqlRdd_generator.py b/src/main/resources/tablerdd/SharkContext_sqlRdd_generator.py new file mode 100755 index 00000000..0f33ca71 --- /dev/null +++ b/src/main/resources/tablerdd/SharkContext_sqlRdd_generator.py @@ -0,0 +1,24 @@ +#!/usr/bin/python +from string import Template +import sys + +from generator_utils import * + +## This script generates functions sqlRdd for SharkContext.scala + +p = sys.stdout + +# The SharkContext declarations +for x in range(2,23): + sqlRddFun = Template( +""" + def sqlRdd[$list1](cmd: String): + RDD[Tuple$num[$list2]] = { + new TableRDD$num[$list2](sql2rdd(cmd), + Seq($list3)) + } +""").substitute(num = x, + list1 = createList(1, x, "T", ": M", ", ", 80, 4), + list2 = createList(1, x, "T", sep=", ", indent = 4), + list3 = createList(1, x, "m[T", "]", sep=", ", indent = 10)) + p.write(sqlRddFun) diff --git a/src/main/resources/tablerdd/TableRDDGenerated_generator.py b/src/main/resources/tablerdd/TableRDDGenerated_generator.py new file mode 100755 index 00000000..45deec03 --- /dev/null +++ b/src/main/resources/tablerdd/TableRDDGenerated_generator.py @@ -0,0 +1,91 @@ +#!/usr/bin/python +from string import Template +import sys +from generator_utils import * + +## This script generates TableRDDGenerated.scala + +p = sys.stdout + +p.write( +""" +/* + * Copyright (C) 2013 The Regents of The University California. + * All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +package shark.api + +// *** This file is auto-generated from TableRDDGenerated_generator.py *** +import scala.language.implicitConversions +import org.apache.spark.rdd.RDD +import org.apache.spark.{TaskContext, Partition} + +import scala.reflect.ClassTag + +class TableSeqRDD(prev: TableRDD) + extends RDD[Seq[Any]](prev) { + + def getSchema = prev.schema + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): Iterator[Seq[Any]] = { + prev.compute(split, context).map( row => + (0 until prev.schema.size).map(i => row.getPrimitive(i)) ) + } +} + +""") + +for x in range(1,23): + + inner = "" + for y in range(1,x+1): + if y % 3 == 1: inner += " " + inner += Template(" row.getPrimitiveGeneric[T$num1]($num2)").substitute(num1=y, num2=y-1) + if y != x: inner += "," + if y % 3 == 0: inner += "\n" + inner += " ) )\n" + + tableClass = Template( +""" +class TableRDD$num[$list](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple$num[$list]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == $num, "Table only has " + tableCols + " columns, expecting $num") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple$num[$list]] = { + prev.compute(split, context).map( row => + new Tuple$num[$list]( + $innerfatlist + } +} +""").substitute(num = x, list = createList(1, x, "T", "", ", ", indent=4), innerfatlist = inner) + + + p.write(tableClass) diff --git a/src/main/resources/tablerdd/generator_utils.py b/src/main/resources/tablerdd/generator_utils.py new file mode 100644 index 00000000..26cdb487 --- /dev/null +++ b/src/main/resources/tablerdd/generator_utils.py @@ -0,0 +1,18 @@ +#!/usr/bin/python +import sys + +# e.g. createList(1,3, "T[", "]", ",") gives T[1],T[2],T[3] +def createList(start, stop, prefix, suffix="", sep = ",", newlineAfter = 70, indent = 0): + res = "" + oneLine = res + for y in range(start,stop+1): + res += prefix + str(y) + suffix + oneLine += prefix + str(y) + suffix + if y != stop: + res += sep + oneLine += sep + if len(oneLine) > newlineAfter: + res += "\n" + " "*indent + oneLine = "" + return res + diff --git a/src/main/resources/tablerdd/rddtable_generator.py b/src/main/resources/tablerdd/rddtable_generator.py new file mode 100755 index 00000000..eda23d05 --- /dev/null +++ b/src/main/resources/tablerdd/rddtable_generator.py @@ -0,0 +1,97 @@ +#!/usr/bin/python +from string import Template +import sys +from generator_utils import * + +## This script generates RDDtable.scala + +p = sys.stdout + +# e.g. 
createList(1,3, "T[", "]", ",") gives T[1],T[2],T[3] +def createList(start, stop, prefix, suffix="", sep = ",", newlineAfter = 70, indent = 0): + res = "" + oneLine = res + for y in range(start,stop+1): + res += prefix + str(y) + suffix + oneLine += prefix + str(y) + suffix + if y != stop: + res += sep + oneLine += sep + if len(oneLine) > newlineAfter: + res += "\n" + " "*indent + oneLine = "" + return res + +### The SparkContext declaration + +prefix = """ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.api + +// *** This file is auto-generated from RDDTable_generator.py *** + +import scala.language.implicitConversions +import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD + +object RDDTableImplicits { + private type C[T] = ClassTag[T] + +""" + +p.write(prefix) + +for x in range(2,23): + + tableClass = Template( +""" + implicit def rddToTable$num[$tmlist] + (rdd: RDD[($tlist)]): RDDTableFunctions = RDDTable(rdd) + +""").substitute(num = x, tmlist = createList(1, x, "T", ": C", ", ", indent=4), tlist = createList(1, x, "T", "", ", ", indent=4)) + p.write(tableClass) + +prefix = """ +} + +object RDDTable { + + private type C[T] = ClassTag[T] + private def ct[T](implicit c: ClassTag[T]) = c +""" + +p.write(prefix) + +for x in range(2,23): + + tableClass = Template( +""" + def apply[$tmlist] + (rdd: RDD[($tlist)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq($mtlist)) + } + +""").substitute(tmlist = createList(1, x, "T", ": C", ", ", indent=4), tlist = createList(1, x, "T", "", ", ", indent=4), + mtlist = createList(1, x, "ct[T", "]", ", ", indent=4)) + p.write(tableClass) + + +p.write("}\n") diff --git a/src/main/scala/shark/CachedTableRecovery.scala b/src/main/scala/shark/CachedTableRecovery.scala deleted file mode 100644 index a5ec5e58..00000000 --- a/src/main/scala/shark/CachedTableRecovery.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (C) 2012 The Regents of The University California. - * All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package shark - -import scala.collection.JavaConversions.asScalaBuffer - -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.metadata.Table - - -/** - * Singleton representing access to the Shark Meta data that gets applied to cached tables - * in the Hive Meta Store. - * All cached tables are tagged with a property CTAS_QUERY_STRING whose value - * represents the query that led to the creation of the cached table. - * This is used to reload RDDs upon server restarts. - */ -object CachedTableRecovery extends LogHelper { - - val db = Hive.get(new HiveConf) - - val QUERY_STRING = "CTAS_QUERY_STRING" - - /** - * Load the cached tables into memory. - * @param cmdRunner , the runner that is responsible - * for taking a cached table query and - * a) create the table metadata in Hive Meta Store - * b) load the table as an RDD in memory - * @see SharkServer for an example usage. - */ - def loadAsRdds(cmdRunner: String => Unit) { - getMeta.foreach { t => - try { - db.dropTable(t._1) - cmdRunner(t._2) - } catch { - case e: Exception => logError("Failed to reload cache table " + t._1, e) - } - } - } - - /** - * Updates the Hive metastore, with cached table metadata. - * The cached table metadata is stored in the Hive metastore - * of each cached table, as a key value pair, the key being - * CTAS_QUERY_STRING and the value being the cached table query itself. - * - * @param cachedTableQueries , a collection of pairs of the form - * (cached table name, cached table query). - */ - def updateMeta(cachedTableQueries : Iterable[(String, String)]): Unit = { - cachedTableQueries.foreach { x => - val newTbl = new Table(db.getTable(x._1).getTTable()) - newTbl.setProperty(QUERY_STRING, x._2) - db.alterTable(x._1, newTbl) - } - } - - /** - * Returns all the Cached table metadata present in the Hive Meta store. - * - * @return sequence of pairs, each pair representing the cached table name - * and the cached table query. 
- */ - def getMeta(): Seq[(String, String)] = { - db.getAllTables().foldLeft(List[(String,String)]())((curr, tableName) => { - val tbl = db.getTable(tableName) - Option(tbl.getProperty(QUERY_STRING)) match { - case Some(q) => curr.::(tableName, q) - case None => curr - } - }) - } -} \ No newline at end of file diff --git a/src/main/scala/shark/SharkCliDriver.scala b/src/main/scala/shark/SharkCliDriver.scala index c7c4f735..b7091466 100755 --- a/src/main/scala/shark/SharkCliDriver.scala +++ b/src/main/scala/shark/SharkCliDriver.scala @@ -25,37 +25,62 @@ import java.io.PrintStream import java.io.UnsupportedEncodingException import java.net.URLClassLoader import java.util.ArrayList -import jline.{History, ConsoleReader} + import scala.collection.JavaConversions._ +import jline.{History, ConsoleReader} + import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} -import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver -import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, Utilities} -import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.parse.ParseDriver +import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hadoop.io.IOUtils +import org.apache.thrift.transport.TSocket -import org.apache.spark.SparkContext +import shark.memstore2.TableRecovery object SharkCliDriver { - - var prompt = "blinkdb" - var prompt2 = " " // when ';' is not yet seen. + val SKIP_RDD_RELOAD_FLAG = "-skipRddReload" + + private var prompt = "blinkdb" + private var prompt2 = " " // when ';' is not yet seen. + private var transport:TSocket = _ + + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. 
+ */ + def installSignalHandler() { + HiveInterruptUtils.add(new HiveInterruptCallback { + override def interrupt() { + // Handle remote execution mode + if (SharkEnv.sc != null) { + SharkEnv.sc.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket().close() + } + } + } + }) + } def main(args: Array[String]) { - val hiveArgs = args.filterNot(_.equals("-loadRdds")) - val loadRdds = hiveArgs.length < args.length + val hiveArgs = args.filterNot(_.equals(SKIP_RDD_RELOAD_FLAG)) + val reloadRdds = hiveArgs.length == args.length val oproc = new OptionsProcessor() if (!oproc.process_stage1(hiveArgs)) { System.exit(1) @@ -73,11 +98,11 @@ object SharkCliDriver { logInitDetailMessage = e.getMessage() } - var ss = new CliSessionState(new HiveConf(classOf[SessionState])) + val ss = new CliSessionState(new HiveConf(classOf[SessionState])) ss.in = System.in try { ss.out = new PrintStream(System.out, true, "UTF-8") - ss.info = new PrintStream(System.err, true, "UTF-8"); + ss.info = new PrintStream(System.err, true, "UTF-8") ss.err = new PrintStream(System.err, true, "UTF-8") } catch { case e: UnsupportedEncodingException => System.exit(3) @@ -134,7 +159,7 @@ object SharkCliDriver { Thread.currentThread().setContextClassLoader(loader) } - var cli = new SharkCliDriver(loadRdds) + val cli = new SharkCliDriver(reloadRdds) cli.setHiveVariables(oproc.getHiveVariables()) // Execute -i init files (always in silent mode) @@ -154,7 +179,7 @@ object SharkCliDriver { System.exit(3) } - var reader = new ConsoleReader() + val reader = new ConsoleReader() reader.setBellEnabled(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) reader.addCompletor(CliDriver.getCommandCompletor()) @@ -163,7 +188,7 @@ object SharkCliDriver { val HISTORYFILE = ".hivehistory" val historyDirectory = System.getProperty("user.home") try { - if ((new File(historyDirectory)).exists()) { + if (new File(historyDirectory).exists()) { val historyFile = historyDirectory + File.separator + HISTORYFILE reader.setHistory(new History(new File(historyFile))) } else { @@ -186,10 +211,15 @@ object SharkCliDriver { "spacesForString", classOf[String]) spacesForStringMethod.setAccessible(true) + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(ss).asInstanceOf[TSocket] + var ret = 0 var prefix = "" - var curDB = getFormattedDbMethod.invoke(null, conf, ss).asInstanceOf[String] + val curDB = getFormattedDbMethod.invoke(null, conf, ss).asInstanceOf[String] var curPrompt = SharkCliDriver.prompt + curDB var dbSpaces = spacesForStringMethod.invoke(null, curDB).asInstanceOf[String] @@ -200,7 +230,7 @@ object SharkCliDriver { } if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { line = prefix + line - ret = cli.processLine(line) + ret = cli.processLine(line, true) prefix = "" val sharkMode = SharkConfVars.getVar(conf, SharkConfVars.EXEC_MODE) == "shark" curPrompt = if (sharkMode) SharkCliDriver.prompt else CliDriver.prompt @@ -216,13 +246,13 @@ object SharkCliDriver { ss.close() System.exit(ret) - } + } // end of main } -class SharkCliDriver(loadRdds: Boolean = false) extends CliDriver with LogHelper { +class SharkCliDriver(reloadRdds: Boolean = true) extends CliDriver with LogHelper { - private val ss = SessionState.get() + private val ss = SessionState.get().asInstanceOf[CliSessionState] private val LOG = 
LogFactory.getLog("CliDriver") @@ -230,13 +260,19 @@ class SharkCliDriver(loadRdds: Boolean = false) extends CliDriver with LogHelper private val conf: Configuration = if (ss != null) ss.getConf() else new Configuration() - SharkConfVars.initializeWithDefaults(conf); + SharkConfVars.initializeWithDefaults(conf) // Force initializing SharkEnv. This is put here but not object SharkCliDriver // because the Hive unit tests do not go through the main() code path. - SharkEnv.init() - - if(loadRdds) CachedTableRecovery.loadAsRdds(processCmd(_)) + if (!ss.isRemoteMode()) { + SharkEnv.init() + if (reloadRdds) { + console.printInfo( + "Reloading cached RDDs from previous Shark sessions... (use %s flag to skip reloading)" + .format(SharkCliDriver.SKIP_RDD_RELOAD_FLAG)) + TableRecovery.reloadRdds(processCmd(_), Some(console)) + } + } def this() = this(false) @@ -307,7 +343,7 @@ class SharkCliDriver(loadRdds: Boolean = false) extends CliDriver with LogHelper try { while (!out.checkError() && qp.getResults(res)) { - res.foreach(out.println(_)) + res.foreach(line => out.println(line)) res.clear() } } catch { diff --git a/src/main/scala/shark/SharkConfVars.scala b/src/main/scala/shark/SharkConfVars.scala index 0506ae77..98bc7ed9 100755 --- a/src/main/scala/shark/SharkConfVars.scala +++ b/src/main/scala/shark/SharkConfVars.scala @@ -17,6 +17,8 @@ package shark +import scala.language.existentials + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.conf.HiveConf @@ -34,18 +36,15 @@ object SharkConfVars { val COLUMNAR_COMPRESSION = new ConfVar("shark.column.compress", true) + // If true, then cache any table whose name ends in "_cached". + val CHECK_TABLENAME_FLAG = new ConfVar("shark.cache.flag.checkTableName", true) + // Specify the initial capacity for ArrayLists used to represent columns in columnar // cache. The default -1 for non-local mode means that Shark will try to estimate // the number of rows by using: partition_size / (num_columns * avg_field_size). val COLUMN_BUILDER_PARTITION_SIZE = new ConfVar("shark.column.partitionSize.mb", if (System.getenv("MASTER") == null) 1 else -1) - // Default storage level for cached tables. - val STORAGE_LEVEL = new ConfVar("shark.cache.storageLevel", "MEMORY_AND_DISK") - - // If true, then cache any table whose name ends in "_cached". - val CHECK_TABLENAME_FLAG = new ConfVar("shark.cache.flag.checkTableName", true) - // Prune map splits for cached tables based on predicates in queries. 
val MAP_PRUNING = new ConfVar("shark.mappruning", true) @@ -68,7 +67,8 @@ object SharkConfVars { conf.set(EXPLAIN_MODE.varname, EXPLAIN_MODE.defaultVal) } if (conf.get(COLUMN_BUILDER_PARTITION_SIZE.varname) == null) { - conf.setInt(COLUMN_BUILDER_PARTITION_SIZE.varname, COLUMN_BUILDER_PARTITION_SIZE.defaultIntVal) + conf.setInt(COLUMN_BUILDER_PARTITION_SIZE.varname, + COLUMN_BUILDER_PARTITION_SIZE.defaultIntVal) } if (conf.get(COLUMNAR_COMPRESSION.varname) == null) { conf.setBoolean(COLUMNAR_COMPRESSION.varname, COLUMNAR_COMPRESSION.defaultBoolVal) @@ -173,18 +173,18 @@ case class ConfVar( } def this(varname: String, defaultVal: Int) = { - this(varname, classOf[Int], null, defaultVal, 0, 0, false) + this(varname, classOf[Int], defaultVal.toString, defaultVal, 0, 0, false) } def this(varname: String, defaultVal: Long) = { - this(varname, classOf[Long], null, 0, defaultVal, 0, false) + this(varname, classOf[Long], defaultVal.toString, 0, defaultVal, 0, false) } def this(varname: String, defaultVal: Float) = { - this(varname, classOf[Float], null, 0, 0, defaultVal, false) + this(varname, classOf[Float], defaultVal.toString, 0, 0, defaultVal, false) } def this(varname: String, defaultVal: Boolean) = { - this(varname, classOf[Boolean], null, 0, 0, 0, defaultVal) + this(varname, classOf[Boolean], defaultVal.toString, 0, 0, 0, defaultVal) } } diff --git a/src/main/scala/shark/SharkContext.scala b/src/main/scala/shark/SharkContext.scala index 3f896211..b20847f3 100755 --- a/src/main/scala/shark/SharkContext.scala +++ b/src/main/scala/shark/SharkContext.scala @@ -22,33 +22,50 @@ import java.util.{ArrayList => JArrayList} import scala.collection.Map import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.common.LogUtils -import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessor import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkConf, SparkContext, SparkEnv} +import org.apache.spark.scheduler.SplitInfo +import org.apache.spark.rdd.RDD import shark.api._ +import shark.tgf.TGF class SharkContext( - master: String, - jobName: String, - sparkHome: String, - jars: Seq[String], - environment: Map[String, String]) - extends SparkContext(master, jobName, sparkHome, jars, environment) { + config: SparkConf, + preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) + extends SparkContext(config, preferredNodeLocationData) { + import SharkContext._ @transient val sparkEnv = SparkEnv.get + private type C[T] = ClassTag[T] + private def ct[T](implicit c : ClassTag[T]) = c + SharkContext.init() - import SharkContext._ + + def this( + master: String, + jobName: String, + sparkHome: String, + jars: Seq[String], + environment: Map[String, String]) { + this( + (new SparkConf()) + .setMaster(master) + .setAppName(jobName) + .setSparkHome(sparkHome) + .setJars(jars) + .setExecutorEnv(environment.toSeq)) + } /** * Execute the command and return the results as a sequence. Each element @@ -105,6 +122,174 @@ class SharkContext( } } + /** + * Execute a SQL command and return the results as a RDD of Seq. The SQL command must be + * a SELECT statement. 
This is useful if the table has more than 22 columns (more than fits in tuples) + * NB: These are auto-generated using resources/tablerdd/SharkContext_sqlRdd_generator.py + */ + def sqlSeqRdd(cmd: String): RDD[Seq[Any]] = { + new TableSeqRDD(sql2rdd(cmd)) + } + + /** + * Execute a SQL command and return the results as a RDD of Tuple. The SQL command must be + * a SELECT statement. + */ + + def sqlRdd[T1: C, T2: C](cmd: String): + RDD[Tuple2[T1, T2]] = { + new TableRDD2[T1, T2](sql2rdd(cmd), + Seq(ct[T1], ct[T2])) + } + + def sqlRdd[T1: C, T2: C, T3: C](cmd: String): + RDD[Tuple3[T1, T2, T3]] = { + new TableRDD3[T1, T2, T3](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C](cmd: String): + RDD[Tuple4[T1, T2, T3, T4]] = { + new TableRDD4[T1, T2, T3, T4](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C](cmd: String): + RDD[Tuple5[T1, T2, T3, T4, T5]] = { + new TableRDD5[T1, T2, T3, T4, T5](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C](cmd: String): + RDD[Tuple6[T1, T2, T3, T4, T5, T6]] = { + new TableRDD6[T1, T2, T3, T4, T5, T6](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C](cmd: String): + RDD[Tuple7[T1, T2, T3, T4, T5, T6, T7]] = { + new TableRDD7[T1, T2, T3, T4, T5, T6, T7](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C](cmd: String): + RDD[Tuple8[T1, T2, T3, T4, T5, T6, T7, T8]] = { + new TableRDD8[T1, T2, T3, T4, T5, T6, T7, T8](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C](cmd: String): + RDD[Tuple9[T1, T2, T3, T4, T5, T6, T7, T8, T9]] = { + new TableRDD9[T1, T2, T3, T4, T5, T6, T7, T8, T9](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C](cmd: String): + RDD[Tuple10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]] = { + new TableRDD10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C](cmd: String): + RDD[Tuple11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]] = { + new TableRDD11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C](cmd: String): + RDD[Tuple12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]] = { + new TableRDD12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C](cmd: String): + RDD[Tuple13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]] = { + new TableRDD13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + 
ct[T13])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C](cmd: String): + RDD[Tuple14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]] = { + new TableRDD14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C](cmd: String): + RDD[Tuple15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]] = { + new TableRDD15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14], ct[T15])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C, T16: C](cmd: String): + RDD[Tuple16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16]] = { + new TableRDD16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14], ct[T15], ct[T16])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C, T16: C, T17: C](cmd: String): + RDD[Tuple17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17]] = { + new TableRDD17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14], ct[T15], ct[T16], ct[T17])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C, T16: C, T17: C, T18: C](cmd: String): + RDD[Tuple18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18]] = { + new TableRDD18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], ct[T18])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C](cmd: String): + RDD[Tuple19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19]] = { + new TableRDD19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], ct[T18], ct[T19])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, T20: C](cmd: String): + RDD[Tuple20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19, T20]] = { + new TableRDD20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19, T20](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], 
ct[T12], + ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], ct[T18], ct[T19], ct[T20])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, T20: C, T21: C](cmd: String): + RDD[Tuple21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19, T20, T21]] = { + new TableRDD21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19, T20, T21](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], ct[T18], ct[T19], ct[T20], ct[T21])) + } + + def sqlRdd[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, T11: C, T12: C, + T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, T20: C, T21: C, T22: C](cmd: String): + RDD[Tuple22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19, T20, T21, T22]] = { + new TableRDD22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, + T19, T20, T21, T22](sql2rdd(cmd), + Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], ct[T10], ct[T11], ct[T12], + ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], ct[T18], ct[T19], ct[T20], ct[T21], ct[T22])) + } + /** * Execute a SQL command and collect the results locally. * @@ -113,6 +298,10 @@ class SharkContext( * @return A ResultSet object with both the schema and the query results. */ def runSql(cmd: String, maxRows: Int = 1000): ResultSet = { + if (cmd.trim.toLowerCase().startsWith("generate")) { + return TGF.execute(cmd.trim, this) + } + SparkEnv.set(sparkEnv) val cmd_trimmed: String = cmd.trim() @@ -179,12 +368,6 @@ object SharkContext { @transient val hiveconf = new HiveConf(classOf[SessionState]) Utils.setAwsCredentials(hiveconf) - try { - LogUtils.initHiveLog4j() - } catch { - case e: LogInitializationException => // Ignore the error. - } - @transient val sessionState = new SessionState(hiveconf) sessionState.out = new PrintStream(System.out, true, "UTF-8") sessionState.err = new PrintStream(System.out, true, "UTF-8") @@ -192,5 +375,3 @@ object SharkContext { // A dummy init to make sure the object is properly initialized. def init() {} } - - diff --git a/src/main/scala/shark/SharkDriver.scala b/src/main/scala/shark/SharkDriver.scala index 1c585559..8e0c4d5e 100755 --- a/src/main/scala/shark/SharkDriver.scala +++ b/src/main/scala/shark/SharkDriver.scala @@ -35,9 +35,13 @@ import org.apache.hadoop.util.StringUtils import shark.api.TableRDD import shark.api.QueryExecutionException -import shark.execution.{SharkExplainTask, SharkExplainWork, SparkTask, SparkWork} +import shark.execution.{SharkDDLTask, SharkDDLWork} +import shark.execution.{SharkExplainTask, SharkExplainWork} +import shark.execution.{SparkLoadWork, SparkLoadTask} +import shark.execution.{SparkTask, SparkWork} import shark.memstore2.ColumnarSerDe import shark.parse.{QueryContext, SharkSemanticAnalyzerFactory} +import shark.util.QueryRewriteUtils /** @@ -51,7 +55,7 @@ private[shark] object SharkDriver extends LogHelper { // A dummy static method so we can make sure the following static code are executed. def runStaticCode() { - logInfo("Initializing object SharkDriver") + logDebug("Initializing object SharkDriver") } def registerSerDe(serdeClass: Class[_ <: SerDe]) { @@ -62,6 +66,8 @@ private[shark] object SharkDriver extends LogHelper { // Task factory. 
Add Shark specific tasks. TaskFactory.taskvec.addAll(Seq( + new TaskFactory.taskTuple(classOf[SharkDDLWork], classOf[SharkDDLTask]), + new TaskFactory.taskTuple(classOf[SparkLoadWork], classOf[SparkLoadTask]), new TaskFactory.taskTuple(classOf[SparkWork], classOf[SparkTask]), new TaskFactory.taskTuple(classOf[SharkExplainWork], classOf[SharkExplainTask]))) @@ -214,9 +220,19 @@ private[shark] class SharkDriver(conf: HiveConf) extends Driver(conf) with LogHe saveSession(queryState) try { - val command = new VariableSubstitution().substitute(conf, _cmd) + val command = { + val varSubbedCmd = new VariableSubstitution().substitute(conf, _cmd).trim + val cmdInUpperCase = varSubbedCmd.toUpperCase + if (cmdInUpperCase.startsWith("CACHE")) { + QueryRewriteUtils.cacheToAlterTable(varSubbedCmd) + } else if (cmdInUpperCase.startsWith("UNCACHE")) { + QueryRewriteUtils.uncacheToAlterTable(varSubbedCmd) + } else { + varSubbedCmd + } + } context = new QueryContext(conf, useTableRddSink) - context.setCmd(_cmd) + context.setCmd(command) context.setTryCount(getTryCount()) val tree = ParseUtils.findRootNonNullToken((new ParseDriver()).parse(command, context)) @@ -236,7 +252,7 @@ private[shark] class SharkDriver(conf: HiveConf) extends Driver(conf) with LogHe sem.analyze(tree, context) } - logInfo("Semantic Analysis Completed") + logDebug("Semantic Analysis Completed") sem.validate() diff --git a/src/main/scala/shark/SharkEnv.scala b/src/main/scala/shark/SharkEnv.scala index 55d646e0..6060ef51 100755 --- a/src/main/scala/shark/SharkEnv.scala +++ b/src/main/scala/shark/SharkEnv.scala @@ -19,56 +19,58 @@ package shark import scala.collection.mutable.{HashMap, HashSet} -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer} import shark.api.JavaSharkContext -import shark.memstore2.MemoryMetadataManager -import shark.execution.serialization.ShuffleSerializer +import shark.execution.serialization.{KryoSerializer, ShuffleSerializer} +import shark.memstore2.{MemoryMetadataManager, Table} import shark.tachyon.TachyonUtilImpl /** A singleton object for the master program. The slaves should not access this. 
*/ object SharkEnv extends LogHelper { - def init(): SparkContext = { + def init(): SharkContext = { if (sc == null) { - sc = new SparkContext( - if (System.getenv("MASTER") == null) "local" else System.getenv("MASTER"), - "Shark::" + java.net.InetAddress.getLocalHost.getHostName, - System.getenv("SPARK_HOME"), - Nil, - executorEnvVars) - sc.addSparkListener(new StatsReportListener()) + val jobName = "Shark::" + java.net.InetAddress.getLocalHost.getHostName + val master = System.getenv("MASTER") + initWithSharkContext(jobName, master) } sc } - def initWithSharkContext(jobName: String, master: String = System.getenv("MASTER")) + def initWithSharkContext( + jobName: String = "Shark::" + java.net.InetAddress.getLocalHost.getHostName, + master: String = System.getenv("MASTER")) : SharkContext = { if (sc != null) { - sc.stop + sc.stop() } sc = new SharkContext( - if (master == null) "local" else master, - jobName, - System.getenv("SPARK_HOME"), - Nil, - executorEnvVars) + if (master == null) "local" else master, + jobName, + System.getenv("SPARK_HOME"), + Nil, + executorEnvVars) sc.addSparkListener(new StatsReportListener()) - sc.asInstanceOf[SharkContext] + sc + } + + def initWithSharkContext(conf: SparkConf): SharkContext = { + conf.setExecutorEnv(executorEnvVars.toSeq) + initWithSharkContext(new SharkContext(conf)) } def initWithSharkContext(newSc: SharkContext): SharkContext = { if (sc != null) { - sc.stop + sc.stop() } - sc = newSc - sc.asInstanceOf[SharkContext] + sc.addSparkListener(new StatsReportListener()) + sc } def initWithJavaSharkContext(jobName: String): JavaSharkContext = { @@ -83,10 +85,7 @@ object SharkEnv extends LogHelper { new JavaSharkContext(initWithSharkContext(newSc.sharkCtx)) } - logInfo("Initializing SharkEnv") - - System.setProperty("spark.serializer", classOf[SparkKryoSerializer].getName) - System.setProperty("spark.kryo.registrator", classOf[KryoRegistrator].getName) + logDebug("Initializing SharkEnv") val executorEnvVars = new HashMap[String, String] executorEnvVars.put("SCALA_HOME", getEnv("SCALA_HOME")) @@ -98,7 +97,9 @@ object SharkEnv extends LogHelper { executorEnvVars.put("TACHYON_MASTER", getEnv("TACHYON_MASTER")) executorEnvVars.put("TACHYON_WAREHOUSE_PATH", getEnv("TACHYON_WAREHOUSE_PATH")) - var sc: SparkContext = _ + val activeSessions = new HashSet[String] + + var sc: SharkContext = _ val shuffleSerializerName = classOf[ShuffleSerializer].getName @@ -114,21 +115,10 @@ object SharkEnv extends LogHelper { val addedFiles = HashSet[String]() val addedJars = HashSet[String]() - def unpersist(key: String): Option[RDD[_]] = { - if (SharkEnv.tachyonUtil.tachyonEnabled() && SharkEnv.tachyonUtil.tableExists(key)) { - if (SharkEnv.tachyonUtil.dropTable(key)) { - logInfo("Table " + key + " was deleted from Tachyon."); - } else { - logWarning("Failed to remove table " + key + " from Tachyon."); - } - } - - memoryMetadataManager.unpersist(key) - } - /** Cleans up and shuts down the Shark environments. */ def stop() { - logInfo("Shutting down Shark Environment") + logDebug("Shutting down Shark Environment") + memoryMetadataManager.shutdown() // Stop the SparkContext if (SharkEnv.sc != null) { sc.stop() @@ -138,6 +128,7 @@ object SharkEnv extends LogHelper { /** Return the value of an environmental variable as a string. 
*/ def getEnv(varname: String) = if (System.getenv(varname) == null) "" else System.getenv(varname) + } diff --git a/src/main/scala/shark/SharkServer.scala b/src/main/scala/shark/SharkServer.scala index cd076a1d..0f42a539 100644 --- a/src/main/scala/shark/SharkServer.scala +++ b/src/main/scala/shark/SharkServer.scala @@ -21,13 +21,15 @@ import java.io.FileOutputStream import java.io.IOException import java.io.PrintStream import java.io.UnsupportedEncodingException +import java.net.InetSocketAddress import java.util.ArrayList import java.util.{List => JavaList} import java.util.Properties import java.util.concurrent.CountDownLatch import scala.annotation.tailrec -import scala.concurrent.ops.spawn +import scala.concurrent._ +import scala.concurrent.ExecutionContext.Implicits.global import org.apache.commons.logging.LogFactory import org.apache.commons.cli.OptionBuilder @@ -50,9 +52,12 @@ import org.apache.thrift.server.TThreadPoolServer import org.apache.thrift.transport.TServerSocket import org.apache.thrift.transport.TTransport import org.apache.thrift.transport.TTransportFactory +import org.apache.thrift.transport.TSocket import org.apache.spark.SparkEnv +import shark.memstore2.TableRecovery + /** * A long-running server compatible with the Hive server. @@ -70,28 +75,55 @@ object SharkServer extends LogHelper { def main(args: Array[String]) { - val cli = new SharkServerCliOptions - cli.parse(args) + val cliOptions = new SharkServerCliOptions + cliOptions.parse(args) // From Hive: It is critical to do this prior to initializing log4j, otherwise // any log specific settings via hiveconf will be ignored. - val hiveconf: Properties = cli.addHiveconfToSystemProperties() + val hiveconf: Properties = cliOptions.addHiveconfToSystemProperties() // From Hive: It is critical to do this here so that log4j is reinitialized // before any of the other core hive classes are loaded LogUtils.initHiveLog4j() val latch = new CountDownLatch(1) - serverTransport = new TServerSocket(cli.port) + serverTransport = new TServerSocket(cliOptions.port) val hfactory = new ThriftHiveProcessorFactory(null, new HiveConf()) { - override def getProcessor(t: TTransport) = - new ThriftHive.Processor(new GatedSharkServerHandler(latch)) + override def getProcessor(t: TTransport) = { + var remoteClient = "Unknown" + + // Seed session ID by a random number + var sessionID = scala.math.round(scala.math.random * 10000000).toString + var jdbcSocket: java.net.Socket = null + if (t.isInstanceOf[TSocket]) { + remoteClient = t.asInstanceOf[TSocket].getSocket() + .getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + .getAddress().toString() + + jdbcSocket = t.asInstanceOf[TSocket].getSocket() + jdbcSocket.setKeepAlive(true) + sessionID = remoteClient + "/" + jdbcSocket + .getRemoteSocketAddress().asInstanceOf[InetSocketAddress].getPort().toString + + "/" + sessionID + + } + logInfo("Audit Log: Connection Initiated with JDBC client - " + remoteClient) + + // Add and enable watcher thread + // This handles both manual killing of session as well as connection drops + val watcher = new JDBCWatcher(jdbcSocket, sessionID) + SharkEnv.activeSessions.add(sessionID) + watcher.start() + + new ThriftHive.Processor(new GatedSharkServerHandler(latch, remoteClient, + sessionID)) + } } val ttServerArgs = new TThreadPoolServer.Args(serverTransport) .processorFactory(hfactory) - .minWorkerThreads(cli.minWorkerThreads) - .maxWorkerThreads(cli.maxWorkerThreads) + .minWorkerThreads(cliOptions.minWorkerThreads) + 
.maxWorkerThreads(cliOptions.maxWorkerThreads) .transportFactory(new TTransportFactory()) .protocolFactory(new TBinaryProtocol.Factory()) server = new TThreadPoolServer(ttServerArgs) @@ -110,12 +142,13 @@ object SharkServer extends LogHelper { } ) - // Optionally load the cached tables. - execLoadRdds(cli.loadRdds, latch) + // Optionally reload cached tables from a previous session. + execLoadRdds(cliOptions.reloadRdds, latch) // Start serving. - val startupMsg = "Starting Shark server on port " + cli.port + " with " + cli.minWorkerThreads + - " min worker threads and " + cli.maxWorkerThreads + " max worker threads" + val startupMsg = "Starting Shark server on port " + cliOptions.port + " with " + + cliOptions.minWorkerThreads + " min worker threads and " + cliOptions.maxWorkerThreads + + " max worker threads." logInfo(startupMsg) println(startupMsg) server.serve() @@ -132,12 +165,11 @@ object SharkServer extends LogHelper { private def execLoadRdds(loadFlag: Boolean, latch:CountDownLatch) { if (!loadFlag) { latch.countDown - } else spawn { + } else future { while (!server.isServing()) {} try { val sshandler = new SharkServerHandler - CachedTableRecovery.loadAsRdds(sshandler.execute(_)) - logInfo("Executed load " + CachedTableRecovery.getMeta) + TableRecovery.reloadRdds(sshandler.execute(_)) } catch { case (e: Exception) => logWarning("Unable to load RDDs upon startup", e) } finally { @@ -145,26 +177,66 @@ object SharkServer extends LogHelper { } } } + + // Detecting socket connection drops relies on TCP keep alives + // The approach is very platform specific on the duration and nature of detection + // Since java does not expose any mechanisms for tuning keepalive configurations, + // the users should explore the server OS settings for the same. + class JDBCWatcher(sock:java.net.Socket, sessionID:String) extends Thread { + + override def run() { + try { + while ((sock == null || sock.isConnected) && SharkEnv.activeSessions.contains(sessionID)) { + if (sock != null) + sock.getOutputStream().write((new Array[Byte](0)).toArray) + logDebug("Session Socket Alive - " + sessionID) + Thread.sleep(2*1000) + } + } catch { + case ioe: IOException => Unit + } + + // Session is terminated either manually or automatically + // clean up the jobs associated with the session ID + logInfo("Session Socket connection lost, cleaning up - " + sessionID) + SharkEnv.sc.cancelJobGroup(sessionID) + } + + } // Used to parse command line arguments for the server. 
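The session handling above rests on Spark's job-group API: every statement from a JDBC connection runs under its session ID as the job group, so when the session ends (an explicit `kill <sessionID>` or a dropped socket noticed by JDBCWatcher) all of its in-flight jobs can be cancelled in one call. Below is a minimal sketch of that pattern against a bare SparkContext; the master URL, session ID and timings are illustrative only and not part of this patch. Detecting the dropped socket itself still depends on OS-level TCP keep-alive tuning, as the comment on JDBCWatcher points out.

import org.apache.spark.SparkContext

object JobGroupCancellationSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "job-group-sketch")  // stand-in for SharkEnv.sc
    val sessionID = "127.0.0.1/54321/42"                       // stand-in for a Shark session ID

    // Query thread: tag everything it runs with the session's job group.
    val runner = new Thread(new Runnable {
      def run() {
        sc.setJobGroup(sessionID, "Session ID = " + sessionID)
        try {
          sc.parallelize(1 to 1000000, 4).map { i => Thread.sleep(1); i }.count()
        } catch {
          case e: Exception => println("job group cancelled: " + e.getMessage)
        }
      }
    })
    runner.start()

    // Watcher side (here: just a timer) cancels the whole group, as JDBCWatcher does on socket loss.
    Thread.sleep(2000)
    sc.cancelJobGroup(sessionID)
    runner.join()
    sc.stop()
  }
}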
class SharkServerCliOptions extends HiveServerCli { - var loadRdds = false + var reloadRdds = false - val OPTION_LOAD_RDDS = "loadRdds" - OPTIONS.addOption(OptionBuilder.create(OPTION_LOAD_RDDS)) + val OPTION_SKIP_RELOAD_RDDS = "skipRddReload" + OPTIONS.addOption(OptionBuilder.create(OPTION_SKIP_RELOAD_RDDS)) override def parse(args: Array[String]) { super.parse(args) - loadRdds = commandLine.hasOption(OPTION_LOAD_RDDS) + reloadRdds = !commandLine.hasOption(OPTION_SKIP_RELOAD_RDDS) } } } -class GatedSharkServerHandler(latch:CountDownLatch) extends SharkServerHandler { +class GatedSharkServerHandler(latch:CountDownLatch, remoteClient:String, + sessionID:String) extends SharkServerHandler { override def execute(cmd: String): Unit = { latch.await - super.execute(cmd) + + logInfo("Audit Log: SessionID=" + sessionID + " client=" + remoteClient + " cmd=" + cmd) + + // Handle cancel commands + if (cmd.startsWith("kill ")) { + logInfo("killing group - " + cmd) + val sessionIDToCancel = cmd.split("\\s+|\\s*;").apply(1) + SharkEnv.activeSessions.remove(sessionIDToCancel) + } else { + // Session ID is used as spark job group + // Job groups control cleanup/cancelling of unneeded jobs on connection terminations + SharkEnv.sc.setJobGroup(sessionID, "Session ID = " + sessionID) + super.execute(cmd) + } } } @@ -271,9 +343,11 @@ class SharkServerHandler extends HiveServerHandler with LogHelper { "" } else { val list: JavaList[String] = fetchN(1) - if (list.isEmpty) + if (list.isEmpty) { "" - else list.get(0) + } else { + list.get(0) + } } } diff --git a/src/main/scala/shark/Utils.scala b/src/main/scala/shark/Utils.scala index 66a42f5c..136d7bcd 100644 --- a/src/main/scala/shark/Utils.scala +++ b/src/main/scala/shark/Utils.scala @@ -19,8 +19,8 @@ package shark import java.io.BufferedReader import java.util.{Map => JMap} - import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} object Utils { @@ -94,4 +94,18 @@ object Utils { new BufferedReader(new InputStreamReader(s3obj.getDataInputStream())) } + /** + * Returns a filter that accepts files not present in the captured snapshot of the `path` + * directory. + */ + def createSnapshotFilter(path: Path, conf: Configuration): PathFilter = { + val fs = path.getFileSystem(conf) + val currentFiles = fs.listStatus(path).map(_.getPath).toSet + val fileFilter = new PathFilter() { + override def accept(path: Path) = { + (!path.getName().startsWith(".") && !currentFiles.contains(path)) + } + } + fileFilter + } } diff --git a/src/main/scala/shark/api/ClassTags.scala b/src/main/scala/shark/api/ClassTags.scala new file mode 100644 index 00000000..b3c17d01 --- /dev/null +++ b/src/main/scala/shark/api/ClassTags.scala @@ -0,0 +1,14 @@ +package shark.api + +import scala.reflect.classTag + +object ClassTags { + // List of primitive ClassTags. 
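ClassTags exists because a Scala primitive and its boxed Java counterpart carry different ClassTags, and RDDs built through the Java API (or holding java.lang.* values) report the boxed form. The short sketch below only shows that distinction using the standard library; nothing in it comes from the patch itself.

import scala.reflect.{classTag, ClassTag}

object BoxedClassTagSketch extends App {
  val scalaInt: ClassTag[Int]               = classTag[Int]                // primitive tag
  val boxedInt: ClassTag[java.lang.Integer] = classTag[java.lang.Integer]  // boxed tag

  println(scalaInt)              // prints: Int
  println(boxedInt)              // prints: java.lang.Integer
  println(scalaInt == boxedInt)  // prints: false

  // Because the two tags are never equal, code that resolves a column type from a
  // ClassTag has to accept both forms, which is what the jBoolean/jByte/... vals are for.
}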
+ val jBoolean = classTag[java.lang.Boolean] + val jByte = classTag[java.lang.Byte] + val jShort = classTag[java.lang.Short] + val jInt = classTag[java.lang.Integer] + val jLong = classTag[java.lang.Long] + val jFloat = classTag[java.lang.Float] + val jDouble = classTag[java.lang.Double] +} diff --git a/src/main/scala/shark/api/DataTypes.java b/src/main/scala/shark/api/DataTypes.java index f8994c7f..4c5ec3f9 100644 --- a/src/main/scala/shark/api/DataTypes.java +++ b/src/main/scala/shark/api/DataTypes.java @@ -17,10 +17,16 @@ package shark.api; +import java.util.Date; import java.util.HashMap; import java.util.Map; +import java.sql.Timestamp; + +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import org.apache.hadoop.hive.serde.Constants; +import shark.Utils$; /** * List of data types defined in Shark APIs. @@ -30,19 +36,38 @@ public class DataTypes { // This list of types are defined in a Java class for better interoperability with Shark's // Java APIs. // Primitive types: - public static final DataType BOOLEAN = new DataType("boolean", Constants.BOOLEAN_TYPE_NAME, true); - public static final DataType TINYINT = new DataType("tinyint", Constants.TINYINT_TYPE_NAME, true); - public static final DataType SMALLINT = - new DataType("smallint", Constants.SMALLINT_TYPE_NAME, true); - public static final DataType INT = new DataType("int", Constants.INT_TYPE_NAME, true); - public static final DataType BIGINT = new DataType("bigint", Constants.BIGINT_TYPE_NAME, true); - public static final DataType FLOAT = new DataType("float", Constants.FLOAT_TYPE_NAME, true); - public static final DataType DOUBLE = new DataType("double", Constants.DOUBLE_TYPE_NAME, true); - public static final DataType STRING = new DataType("string", Constants.STRING_TYPE_NAME, true); - public static final DataType TIMESTAMP = - new DataType("timestamp", Constants.TIMESTAMP_TYPE_NAME, true); - public static final DataType DATE = new DataType("date", Constants.DATE_TYPE_NAME, true); - public static final DataType BINARY = new DataType("binary", Constants.BINARY_TYPE_NAME, true); + public static final DataType BOOLEAN = new DataType( + "boolean", Constants.BOOLEAN_TYPE_NAME, true); + + public static final DataType TINYINT = new DataType( + "tinyint", Constants.TINYINT_TYPE_NAME, true); + + public static final DataType SMALLINT = new DataType( + "smallint", Constants.SMALLINT_TYPE_NAME, true); + + public static final DataType INT = new DataType( + "int", Constants.INT_TYPE_NAME, true); + + public static final DataType BIGINT = new DataType( + "bigint", Constants.BIGINT_TYPE_NAME, true); + + public static final DataType FLOAT = new DataType( + "float", Constants.FLOAT_TYPE_NAME, true); + + public static final DataType DOUBLE = new DataType( + "double", Constants.DOUBLE_TYPE_NAME, true); + + public static final DataType STRING = new DataType( + "string", Constants.STRING_TYPE_NAME, true); + + public static final DataType TIMESTAMP = new DataType( + "timestamp", Constants.TIMESTAMP_TYPE_NAME, true); + + public static final DataType DATE = new DataType( + "date", Constants.DATE_TYPE_NAME, true); + + public static final DataType BINARY = new DataType( + "binary", Constants.BINARY_TYPE_NAME, true); // Complex types: // TODO: handle complex types. 
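The hunk that follows adds DataTypes.fromClassTag, which resolves either a Scala primitive tag or the matching boxed tag from ClassTags to one of the constants above. A rough usage sketch from Scala, assuming only what the patch itself shows (the tag-to-type mapping and UnknownDataTypeException):

import scala.reflect.classTag
import shark.api.{ClassTags, DataTypes, UnknownDataTypeException}

object FromClassTagSketch extends App {
  // A Scala primitive tag and its boxed counterpart resolve to the same Hive type.
  assert(DataTypes.fromClassTag(classTag[Long]) == DataTypes.BIGINT)
  assert(DataTypes.fromClassTag(ClassTags.jLong) == DataTypes.BIGINT)
  assert(DataTypes.fromClassTag(classTag[String]) == DataTypes.STRING)

  // Anything outside the listed primitives is rejected.
  try {
    DataTypes.fromClassTag(classTag[Array[Byte]])
  } catch {
    case e: UnknownDataTypeException => println("unsupported element type: " + e.getMessage)
  }
}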
@@ -82,4 +107,31 @@ public static DataType fromHiveType(String hiveType) throws UnknownDataTypeExcep return type; } } + + public static DataType fromClassTag(ClassTag m) throws UnknownDataTypeException { + if (m.equals(ClassTag$.MODULE$.Boolean()) || m.equals(ClassTags$.MODULE$.jBoolean())) { + return INT; + } else if (m.equals(ClassTag$.MODULE$.Byte()) || m.equals(ClassTags$.MODULE$.jByte())){ + return TINYINT; + } else if (m.equals(ClassTag$.MODULE$.Short()) || m.equals(ClassTags$.MODULE$.jShort())) { + return SMALLINT; + } else if (m.equals(ClassTag$.MODULE$.Int()) || m.equals(ClassTags$.MODULE$.jInt())) { + return INT; + } else if (m.equals(ClassTag$.MODULE$.Long()) || m.equals(ClassTags$.MODULE$.jLong())) { + return BIGINT; + } else if (m.equals(ClassTag$.MODULE$.Float()) || m.equals(ClassTags$.MODULE$.jFloat())) { + return FLOAT; + } else if (m.equals(ClassTag$.MODULE$.Double()) || m.equals(ClassTags$.MODULE$.jDouble())) { + return DOUBLE; + } else if (m.equals(ClassTag$.MODULE$.apply(String.class))) { + return STRING; + } else if (m.equals(ClassTag$.MODULE$.apply(Timestamp.class))) { + return TIMESTAMP; + } else if (m.equals(ClassTag$.MODULE$.apply(Date.class))) { + return DATE; + } else { + throw new UnknownDataTypeException(m.toString()); + } + // TODO: binary data type. + } } diff --git a/src/main/scala/shark/api/JavaTableRDD.scala b/src/main/scala/shark/api/JavaTableRDD.scala index de111173..50be2d4f 100644 --- a/src/main/scala/shark/api/JavaTableRDD.scala +++ b/src/main/scala/shark/api/JavaTableRDD.scala @@ -17,6 +17,8 @@ package shark.api +import scala.reflect.ClassTag + import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.api.java.JavaRDDLike import org.apache.spark.rdd.RDD @@ -29,7 +31,7 @@ class JavaTableRDD(val rdd: RDD[Row], val schema: Array[ColumnDesc]) override def wrapRDD(rdd: RDD[Row]): JavaTableRDD = new JavaTableRDD(rdd, schema) // Common RDD functions - override val classManifest: ClassManifest[Row] = implicitly[ClassManifest[Row]] + override val classTag: ClassTag[Row] = implicitly[ClassTag[Row]] // This shouldn't be necessary, but we seem to need this to get first() to return Row // instead of Object; possibly a compiler bug? diff --git a/src/main/scala/shark/api/RDDTable.scala b/src/main/scala/shark/api/RDDTable.scala new file mode 100644 index 00000000..c0496e9e --- /dev/null +++ b/src/main/scala/shark/api/RDDTable.scala @@ -0,0 +1,347 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package shark.api + +// *** This file is auto-generated from RDDTable_generator.py *** +import scala.language.implicitConversions +import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD + +object RDDTableImplicits { + private type C[T] = ClassTag[T] + + + implicit def rddToTable2[T1: C, T2: C] + (rdd: RDD[(T1, T2)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable3[T1: C, T2: C, T3: C] + (rdd: RDD[(T1, T2, T3)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable4[T1: C, T2: C, T3: C, T4: C] + (rdd: RDD[(T1, T2, T3, T4)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable5[T1: C, T2: C, T3: C, T4: C, T5: C] + (rdd: RDD[(T1, T2, T3, T4, T5)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable6[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable7[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable8[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable9[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable10[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable11[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable12[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable13[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable14[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable15[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable16[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable17[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable18[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, 
T11, T12, T13, T14, T15, T16, + T17, T18)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable19[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable20[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, + T20: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable21[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, + T20: C, T21: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21)]): RDDTableFunctions = RDDTable(rdd) + + + implicit def rddToTable22[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, + T20: C, T21: C, T22: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21, T22)]): RDDTableFunctions = RDDTable(rdd) + + +} + +object RDDTable { + + private type C[T] = ClassTag[T] + private def ct[T](implicit c: ClassTag[T]) = c + + def apply[T1: C, T2: C] + (rdd: RDD[(T1, T2)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2])) + } + + + def apply[T1: C, T2: C, T3: C] + (rdd: RDD[(T1, T2, T3)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C] + (rdd: RDD[(T1, T2, T3, T4)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C] + (rdd: RDD[(T1, T2, T3, T4, T5)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8)]) = { + val classTag = 
implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: 
C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15], ct[T16])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15], ct[T16], ct[T17])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], + ct[T18])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], + ct[T18], ct[T19])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, + T20: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], + ct[T18], ct[T19], ct[T20])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, + T20: C, T21: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15], ct[T16], 
ct[T17], + ct[T18], ct[T19], ct[T20], ct[T21])) + } + + + def apply[T1: C, T2: C, T3: C, T4: C, T5: C, T6: C, T7: C, T8: C, T9: C, T10: C, + T11: C, T12: C, T13: C, T14: C, T15: C, T16: C, T17: C, T18: C, T19: C, + T20: C, T21: C, T22: C] + (rdd: RDD[(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21, T22)]) = { + val classTag = implicitly[ClassTag[Seq[Any]]] + val rddSeq: RDD[Seq[_]] = rdd.map(t => t.productIterator.toList.asInstanceOf[Seq[Any]])(classTag) + new RDDTableFunctions(rddSeq, Seq(ct[T1], ct[T2], ct[T3], ct[T4], ct[T5], ct[T6], ct[T7], ct[T8], ct[T9], + ct[T10], ct[T11], ct[T12], ct[T13], ct[T14], ct[T15], ct[T16], ct[T17], + ct[T18], ct[T19], ct[T20], ct[T21], ct[T22])) + } + +} diff --git a/src/main/scala/shark/api/RDDTableFunctions.scala b/src/main/scala/shark/api/RDDTableFunctions.scala new file mode 100644 index 00000000..06f42b4e --- /dev/null +++ b/src/main/scala/shark/api/RDDTableFunctions.scala @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.api + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.hadoop.hive.ql.metadata.Hive + +import org.apache.spark.rdd.RDD + +import shark.{SharkContext, SharkEnv} +import shark.memstore2.{CacheType, TablePartitionStats, TablePartition, TablePartitionBuilder} +import shark.util.HiveUtils + + +class RDDTableFunctions(self: RDD[Seq[_]], classTags: Seq[ClassTag[_]]) { + + def saveAsTable(tableName: String, fields: Seq[String]): Boolean = { + require(fields.size == this.classTags.size, + "Number of column names != number of fields in the RDD.") + + // Get a local copy of the classTags so we don't need to serialize this object. + val classTags = this.classTags + + val statsAcc = SharkEnv.sc.accumulableCollection(ArrayBuffer[(Int, TablePartitionStats)]()) + + // Create the RDD object. + val rdd = self.mapPartitionsWithIndex { case(partitionIndex, iter) => + val ois = classTags.map(HiveUtils.getJavaPrimitiveObjectInspector) + val builder = new TablePartitionBuilder(ois, 1000000, shouldCompress = false) + + for (p <- iter) { + builder.incrementRowCount() + // TODO: this is not the most efficient code to do the insertion ... + p.zipWithIndex.foreach { case (v, i) => + builder.append(i, v.asInstanceOf[Object], ois(i)) + } + } + + statsAcc += Tuple2(partitionIndex, builder.asInstanceOf[TablePartitionBuilder].stats) + Iterator(builder.build()) + }.persist() + + var isSucessfulCreateTable = HiveUtils.createTableInHive( + tableName, fields, classTags, Hive.get().getConf()) + + // Put the table in the metastore. Only proceed if the DDL statement is executed successfully. + val databaseName = Hive.get(SharkContext.hiveconf).getCurrentDatabase() + if (isSucessfulCreateTable) { + // Create an entry in the MemoryMetadataManager. 
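RDDTableFunctions.saveAsTable is the user-facing end of the generated RDDTable wrappers above: it turns an RDD of tuples into a cached Shark table whose column types are derived from the tuple's ClassTags. A usage sketch follows; it assumes an already initialized SharkContext (SharkEnv.sc), and the table and column names are made up for illustration.

import org.apache.spark.rdd.RDD
import shark.SharkEnv
import shark.api.RDDTable

object SaveAsTableSketch {
  def main(args: Array[String]) {
    val sc = SharkEnv.sc  // assumes SharkEnv.init() has already run, e.g. inside the Shark shell

    // Two columns, name (STRING) and age (INT); the column count must match the tuple arity.
    val people: RDD[(String, Int)] = sc.parallelize(Seq(("alice", 30), ("bob", 25)))

    val created = RDDTable(people).saveAsTable("people_cached", Seq("name", "age"))
    if (!created) {
      println("CREATE TABLE failed; the DDL error is printed by Hive's DDLTask.")
    }

    // With `import shark.api.RDDTableImplicits._` in scope, the same call can be written as
    //   people.saveAsTable("people_cached", Seq("name", "age"))
  }
}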
+ val newTable = SharkEnv.memoryMetadataManager.createMemoryTable( + databaseName, tableName, CacheType.MEMORY) + try { + // Force evaluate to put the data in memory. + rdd.context.runJob(rdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) + } catch { + case _: Exception => { + // Intercept the exception thrown by SparkContext#runJob() and handle it silently. The + // exception message should already be printed to the console by DDLTask#execute(). + HiveUtils.dropTableInHive(tableName) + // Drop the table entry from MemoryMetadataManager. + SharkEnv.memoryMetadataManager.removeTable(databaseName, tableName) + isSucessfulCreateTable = false + } + } + newTable.put(rdd, statsAcc.value.toMap) + } + return isSucessfulCreateTable + } +} diff --git a/src/main/scala/shark/api/ResultSet.scala b/src/main/scala/shark/api/ResultSet.scala index abd49c99..682d769b 100644 --- a/src/main/scala/shark/api/ResultSet.scala +++ b/src/main/scala/shark/api/ResultSet.scala @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.api import java.util.{Arrays, Collections, List => JList} @@ -30,4 +47,4 @@ class ResultSet private[shark](_schema: Array[ColumnDesc], _results: Array[Array _results.map(row => row.mkString("\t")).mkString("\n") } -} \ No newline at end of file +} diff --git a/src/main/scala/shark/api/Row.scala b/src/main/scala/shark/api/Row.scala index b6f2224a..f91cd05d 100644 --- a/src/main/scala/shark/api/Row.scala +++ b/src/main/scala/shark/api/Row.scala @@ -17,8 +17,10 @@ package shark.api +import org.apache.hadoop.io.Text +import org.apache.hadoop.hive.serde2.ByteStream +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ /** @@ -31,14 +33,24 @@ class Row(val rawdata: Any, val colname2indexMap: Map[String, Int], val oi: Stru def apply(field: String): Object = apply(colname2indexMap(field)) def apply(field: Int): Object = { - val ref = oi.getAllStructFieldRefs.get(field) - val data = oi.getStructFieldData(rawdata, ref) - + val ref: StructField = oi.getAllStructFieldRefs.get(field) + val data: Object = oi.getStructFieldData(rawdata, ref) ref.getFieldObjectInspector match { case poi: PrimitiveObjectInspector => poi.getPrimitiveJavaObject(data) - case loi: ListObjectInspector => loi.getList(data) - case moi: MapObjectInspector => moi.getMap(data) - case soi: StructObjectInspector => soi.getStructFieldsDataAsList(data) + case _: ListObjectInspector | _: MapObjectInspector | _: StructObjectInspector => + // For complex types, return the string representation of data. 
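The Row changes being made here replace the old raw Hive-object accessors with string rendering for complex columns, while primitive columns remain available through getPrimitiveGeneric. A sketch of what that looks like from query code; sql2rdd, the table, and the column names are assumptions for illustration and are not part of this patch. Returning complex values as strings keeps the public Row API free of Hive's lazy serde types, the concern raised in the comment this hunk removes.

import shark.SharkEnv

object RowAccessSketch {
  def main(args: Array[String]) {
    val sc = SharkEnv.sc  // assumes an initialized SharkContext

    // Hypothetical query over a hypothetical table; sql2rdd is assumed to return an RDD of Row.
    val rows = sc.sql2rdd("SELECT name, age, tags FROM people_with_tags")

    val projected = rows.map { row =>
      (row.getPrimitiveGeneric[String]("name"),  // typed primitive access
       row("age"),                               // untyped access, returns a boxed java.lang.Object
       row.getList("tags"))                      // complex columns now come back as a single String
    }
    projected.collect().foreach(println)
  }
}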
+ val stream = new ByteStream.Output() + LazySimpleSerDe.serialize( + stream, // out + data, // obj + ref.getFieldObjectInspector, // objInspector + Array[Byte](1, 2, 3, 4, 5, 6, 7, 8), // separators + 1, // level + new Text(""), // nullSequence + true, // escaped + 92, // escapeChar + Row.needsEscape) // needsEscape + stream.toString } } @@ -98,34 +110,39 @@ class Row(val rawdata: Any, val colname2indexMap: Map[String, Int], val oi: Stru ref.getFieldObjectInspector.asInstanceOf[PrimitiveObjectInspector].getPrimitiveJavaObject(data) } + def getPrimitiveGeneric[T](field: Int): T = getPrimitive(field).asInstanceOf[T] + + def getPrimitiveGeneric[T](field: String): T = getPrimitiveGeneric[T](colname2indexMap(field)) + ///////////////////////////////////////////////////////////////////////////////////////////////// - // Complex data types - // rxin: I am not sure how useful these APIs are since they would expose the Hive internal - // data structure. For example, in the case of an array of strings, getList would actually - // return a List of LazyString. + // Complex data types - only return the string representation of them for now. ///////////////////////////////////////////////////////////////////////////////////////////////// - def getList(field: String): java.util.List[_] = getList(colname2indexMap(field)) + def getList(field: String): String = getList(colname2indexMap(field)) - def getMap(field: String): java.util.Map[_, _] = getMap(colname2indexMap(field)) + def getMap(field: String): String = getMap(colname2indexMap(field)) - def getStruct(field: String): java.util.List[Object] = getStruct(colname2indexMap(field)) + def getStruct(field: String): String = getStruct(colname2indexMap(field)) - def getList(field: Int): java.util.List[_] = { - val ref = oi.getAllStructFieldRefs.get(field) - val data = oi.getStructFieldData(rawdata, ref) - ref.getFieldObjectInspector.asInstanceOf[ListObjectInspector].getList(data) - } + def getList(field: Int): String = apply(field).asInstanceOf[String] - def getMap(field: Int): java.util.Map[_, _] = { - val ref = oi.getAllStructFieldRefs.get(field) - val data = oi.getStructFieldData(rawdata, ref) - ref.getFieldObjectInspector.asInstanceOf[MapObjectInspector].getMap(data) - } + def getMap(field: Int): String = apply(field).asInstanceOf[String] - def getStruct(field: Int): java.util.List[Object] = { - val ref = oi.getAllStructFieldRefs.get(field) - val data = oi.getStructFieldData(rawdata, ref) - ref.getFieldObjectInspector.asInstanceOf[StructObjectInspector].getStructFieldsDataAsList(data) - } + def getStruct(field: Int): String = apply(field).asInstanceOf[String] +} + + +private[shark] object Row { + // For Hive's LazySimpleSerDe + val needsEscape = Array[Boolean]( + false, true, true, true, true, true, true, true, true, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, true, + false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, 
false, false, false, false, false, + false, false, false, false, false, false, false, false, false) } diff --git a/src/main/scala/shark/api/TableRDD.scala b/src/main/scala/shark/api/TableRDD.scala index 75929f80..89122a18 100644 --- a/src/main/scala/shark/api/TableRDD.scala +++ b/src/main/scala/shark/api/TableRDD.scala @@ -75,6 +75,3 @@ class TableRDD( } } } - - - diff --git a/src/main/scala/shark/api/TableRDDGenerated.scala b/src/main/scala/shark/api/TableRDDGenerated.scala new file mode 100644 index 00000000..a0189831 --- /dev/null +++ b/src/main/scala/shark/api/TableRDDGenerated.scala @@ -0,0 +1,649 @@ + +/* + * Copyright (C) 2013 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +package shark.api + +// *** This file is auto-generated from TableRDDGenerated_generator.py *** +import scala.language.implicitConversions +import org.apache.spark.rdd.RDD +import org.apache.spark.{TaskContext, Partition} + +import scala.reflect.ClassTag + +class TableSeqRDD(prev: TableRDD) + extends RDD[Seq[Any]](prev) { + + def getSchema = prev.schema + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): Iterator[Seq[Any]] = { + prev.compute(split, context).map( row => + (0 until prev.schema.size).map(i => row.getPrimitive(i)) ) + } +} + + +class TableRDD1[T1](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple1[T1]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 1, "Table only has " + tableCols + " columns, expecting 1") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple1[T1]] = { + prev.compute(split, context).map( row => + new Tuple1[T1]( + row.getPrimitiveGeneric[T1](0) ) ) + + } +} + +class TableRDD2[T1, T2](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple2[T1, T2]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 2, "Table only has " + tableCols + " columns, expecting 2") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple2[T1, T2]] = { + prev.compute(split, context).map( row => + new Tuple2[T1, T2]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1) ) ) + + } +} + +class TableRDD3[T1, T2, T3](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple3[T1, 
T2, T3]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 3, "Table only has " + tableCols + " columns, expecting 3") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple3[T1, T2, T3]] = { + prev.compute(split, context).map( row => + new Tuple3[T1, T2, T3]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2) + ) ) + + } +} + +class TableRDD4[T1, T2, T3, T4](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple4[T1, T2, T3, T4]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 4, "Table only has " + tableCols + " columns, expecting 4") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple4[T1, T2, T3, T4]] = { + prev.compute(split, context).map( row => + new Tuple4[T1, T2, T3, T4]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3) ) ) + + } +} + +class TableRDD5[T1, T2, T3, T4, T5](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple5[T1, T2, T3, T4, T5]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 5, "Table only has " + tableCols + " columns, expecting 5") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple5[T1, T2, T3, T4, T5]] = { + prev.compute(split, context).map( row => + new Tuple5[T1, T2, T3, T4, T5]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4) ) ) + + } +} + +class TableRDD6[T1, T2, T3, T4, T5, T6](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple6[T1, T2, T3, T4, T5, T6]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 6, "Table only has " + tableCols + " columns, expecting 6") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple6[T1, T2, T3, T4, T5, T6]] = { + prev.compute(split, context).map( row => + new Tuple6[T1, T2, T3, T4, T5, T6]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), 
row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5) + ) ) + + } +} + +class TableRDD7[T1, T2, T3, T4, T5, T6, T7](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple7[T1, T2, T3, T4, T5, T6, T7]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 7, "Table only has " + tableCols + " columns, expecting 7") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple7[T1, T2, T3, T4, T5, T6, T7]] = { + prev.compute(split, context).map( row => + new Tuple7[T1, T2, T3, T4, T5, T6, T7]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6) ) ) + + } +} + +class TableRDD8[T1, T2, T3, T4, T5, T6, T7, T8](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple8[T1, T2, T3, T4, T5, T6, T7, T8]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 8, "Table only has " + tableCols + " columns, expecting 8") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple8[T1, T2, T3, T4, T5, T6, T7, T8]] = { + prev.compute(split, context).map( row => + new Tuple8[T1, T2, T3, T4, T5, T6, T7, T8]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7) ) ) + + } +} + +class TableRDD9[T1, T2, T3, T4, T5, T6, T7, T8, T9](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple9[T1, T2, T3, T4, T5, T6, T7, T8, T9]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 9, "Table only has " + tableCols + " columns, expecting 9") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple9[T1, T2, T3, T4, T5, T6, T7, T8, T9]] = { + prev.compute(split, context).map( row => + new Tuple9[T1, T2, T3, T4, T5, T6, T7, T8, T9]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8) + ) ) + + } +} + +class TableRDD10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]](prev) { + def schema = 
prev.schema + + private val tableCols = schema.size + require(tableCols == 10, "Table only has " + tableCols + " columns, expecting 10") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]] = { + prev.compute(split, context).map( row => + new Tuple10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9) ) ) + + } +} + +class TableRDD11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 11, "Table only has " + tableCols + " columns, expecting 11") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]] = { + prev.compute(split, context).map( row => + new Tuple11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10) ) ) + + } +} + +class TableRDD12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 12, "Table only has " + tableCols + " columns, expecting 12") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]] = { + prev.compute(split, context).map( row => + new Tuple12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), 
row.getPrimitiveGeneric[T12](11) + ) ) + + } +} + +class TableRDD13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 13, "Table only has " + tableCols + " columns, expecting 13") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]] = { + prev.compute(split, context).map( row => + new Tuple13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12) ) ) + + } +} + +class TableRDD14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 14, "Table only has " + tableCols + " columns, expecting 14") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]] = { + prev.compute(split, context).map( row => + new Tuple14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13) ) ) + + } +} + +class TableRDD15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 15, "Table only has " + tableCols + " columns, expecting 15") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: 
Partition, context: TaskContext): + Iterator[Tuple15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]] = { + prev.compute(split, context).map( row => + new Tuple15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14) + ) ) + + } +} + +class TableRDD16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 16, "Table only has " + tableCols + " columns, expecting 16") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16]] = { + prev.compute(split, context).map( row => + new Tuple16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14), + row.getPrimitiveGeneric[T16](15) ) ) + + } +} + +class TableRDD17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 17, "Table only has " + tableCols + " columns, expecting 17") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17]] = { + prev.compute(split, context).map( row => + new Tuple17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), 
row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14), + row.getPrimitiveGeneric[T16](15), row.getPrimitiveGeneric[T17](16) ) ) + + } +} + +class TableRDD18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 18, "Table only has " + tableCols + " columns, expecting 18") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18]] = { + prev.compute(split, context).map( row => + new Tuple18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14), + row.getPrimitiveGeneric[T16](15), row.getPrimitiveGeneric[T17](16), row.getPrimitiveGeneric[T18](17) + ) ) + + } +} + +class TableRDD19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 19, "Table only has " + tableCols + " columns, expecting 19") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19]] = { + prev.compute(split, context).map( row => + new Tuple19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14), 
+ row.getPrimitiveGeneric[T16](15), row.getPrimitiveGeneric[T17](16), row.getPrimitiveGeneric[T18](17), + row.getPrimitiveGeneric[T19](18) ) ) + + } +} + +class TableRDD20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 20, "Table only has " + tableCols + " columns, expecting 20") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20]] = { + prev.compute(split, context).map( row => + new Tuple20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14), + row.getPrimitiveGeneric[T16](15), row.getPrimitiveGeneric[T17](16), row.getPrimitiveGeneric[T18](17), + row.getPrimitiveGeneric[T19](18), row.getPrimitiveGeneric[T20](19) ) ) + + } +} + +class TableRDD21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 21, "Table only has " + tableCols + " columns, expecting 21") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21]] = { + prev.compute(split, context).map( row => + new Tuple21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14), + row.getPrimitiveGeneric[T16](15), 
row.getPrimitiveGeneric[T17](16), row.getPrimitiveGeneric[T18](17), + row.getPrimitiveGeneric[T19](18), row.getPrimitiveGeneric[T20](19), row.getPrimitiveGeneric[T21](20) + ) ) + + } +} + +class TableRDD22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21, T22](prev: TableRDD, + tags: Seq[ClassTag[_]]) + extends RDD[Tuple22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21, T22]](prev) { + def schema = prev.schema + + private val tableCols = schema.size + require(tableCols == 22, "Table only has " + tableCols + " columns, expecting 22") + + tags.zipWithIndex.foreach{ case (m, i) => if (DataTypes.fromClassTag(m) != schema(i).dataType) + throw new IllegalArgumentException( + "Type mismatch on column " + (i + 1) + ", expected " + DataTypes.fromClassTag(m) + " got " + schema(i).dataType) } + + override def getPartitions = prev.getPartitions + + override def compute(split: Partition, context: TaskContext): + Iterator[Tuple22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21, T22]] = { + prev.compute(split, context).map( row => + new Tuple22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, + T17, T18, T19, T20, T21, T22]( + row.getPrimitiveGeneric[T1](0), row.getPrimitiveGeneric[T2](1), row.getPrimitiveGeneric[T3](2), + row.getPrimitiveGeneric[T4](3), row.getPrimitiveGeneric[T5](4), row.getPrimitiveGeneric[T6](5), + row.getPrimitiveGeneric[T7](6), row.getPrimitiveGeneric[T8](7), row.getPrimitiveGeneric[T9](8), + row.getPrimitiveGeneric[T10](9), row.getPrimitiveGeneric[T11](10), row.getPrimitiveGeneric[T12](11), + row.getPrimitiveGeneric[T13](12), row.getPrimitiveGeneric[T14](13), row.getPrimitiveGeneric[T15](14), + row.getPrimitiveGeneric[T16](15), row.getPrimitiveGeneric[T17](16), row.getPrimitiveGeneric[T18](17), + row.getPrimitiveGeneric[T19](18), row.getPrimitiveGeneric[T20](19), row.getPrimitiveGeneric[T21](20), + row.getPrimitiveGeneric[T22](21) ) ) + + } +} diff --git a/src/main/scala/shark/execution/CoGroupedRDD.scala b/src/main/scala/shark/execution/CoGroupedRDD.scala index e803db90..e5806e34 100644 --- a/src/main/scala/shark/execution/CoGroupedRDD.scala +++ b/src/main/scala/shark/execution/CoGroupedRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.language.existentials + import java.io.{ObjectOutputStream, IOException} import java.util.{HashMap => JHashMap} @@ -49,12 +51,15 @@ case class NarrowCoGroupSplitDep( case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep +// equals not implemented style error +// scalastyle:off class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx } +// scalastyle:on class CoGroupAggregator extends Aggregator[Any, Any, ArrayBuffer[Any]]( @@ -72,10 +77,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) override def getDependencies: Seq[Dependency[_]] = { rdds.map { rdd => if (rdd.partitioner == Some(part)) { - logInfo("Adding one-to-one dependency with " + rdd) + logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { - logInfo("Adding shuffle dependency with " + rdd) + logDebug("Adding shuffle dependency with " + rdd) new ShuffleDependency[Any, Any](rdd, part, SharkEnv.shuffleSerializerName) } } @@ -112,7 +117,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: 
Partitioner) } values } - val serializer = SparkEnv.get.serializerManager.get(SharkEnv.shuffleSerializerName) + val serializer = SparkEnv.get.serializerManager.get(SharkEnv.shuffleSerializerName, SparkEnv.get.conf) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { // Read them from the parent @@ -122,11 +127,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) // Read map outputs of shuffle def mergePair(pair: (K, Any)) { getSeq(pair._1)(depNum) += pair._2 } val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[(K, Seq[Any])](shuffleId, split.index, context.taskMetrics, serializer) + fetcher.fetch[(K, Seq[Any])](shuffleId, split.index, context, serializer) .foreach(mergePair) } } - map.iterator + new InterruptibleIterator(context, map.iterator) } override def clearDependencies() { diff --git a/src/main/scala/shark/execution/CommonJoinOperator.scala b/src/main/scala/shark/execution/CommonJoinOperator.scala index a081ade5..da258864 100755 --- a/src/main/scala/shark/execution/CommonJoinOperator.scala +++ b/src/main/scala/shark/execution/CommonJoinOperator.scala @@ -17,30 +17,23 @@ package shark.execution -import java.util.{HashMap => JavaHashMap, List => JavaList} +import java.util.{HashMap => JavaHashMap, List => JavaList, ArrayList =>JavaArrayList} -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ -import scala.reflect.BeanProperty +import scala.beans.BeanProperty +import scala.reflect.ClassTag -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator -import org.apache.hadoop.hive.ql.exec.{CommonJoinOperator => HiveCommonJoinOperator} import org.apache.hadoop.hive.ql.exec.{JoinUtil => HiveJoinUtil} -import org.apache.hadoop.hive.ql.plan.{ExprNodeDesc, JoinCondDesc, JoinDesc, TableDesc} -import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.ql.plan.{JoinCondDesc, JoinDesc} import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector} - -import org.apache.spark.rdd.{RDD, UnionRDD} -import org.apache.spark.SparkContext.rddToPairRDDFunctions +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory import shark.SharkConfVars -abstract class CommonJoinOperator[JOINDESCTYPE <: JoinDesc, T <: HiveCommonJoinOperator[JOINDESCTYPE]] - extends NaryOperator[T] { +abstract class CommonJoinOperator[T <: JoinDesc] extends NaryOperator[T] { - @BeanProperty var conf: JOINDESCTYPE = _ + @BeanProperty var conf: T = _ // Order in which the results should be output. @BeanProperty var order: Array[java.lang.Byte] = _ // condn determines join property (left, right, outer joins). 
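(Editorial note on the CoGroupedRDD hunk above: the changes track Spark 0.9 API updates. The shuffle serializer lookup now needs the SparkConf, the fetcher takes the whole TaskContext rather than just its metrics, and the resulting iterator is wrapped so a killed task stops consuming map outputs. A minimal sketch of that wrapping idiom, assuming Spark 0.9-era classes; the helper name is illustrative, not from the patch.)

    import org.apache.spark.{InterruptibleIterator, TaskContext}

    object InterruptibleSketch {
      // Returns an iterator that re-checks the owning task's interruption
      // status as it is consumed, so a cancelled task stops reading map outputs.
      def asInterruptible[T](context: TaskContext, underlying: Iterator[T]): Iterator[T] =
        new InterruptibleIterator(context, underlying)
    }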
@@ -62,8 +55,11 @@ abstract class CommonJoinOperator[JOINDESCTYPE <: JoinDesc, T <: HiveCommonJoinO @transient var noOuterJoin: Boolean = _ override def initializeOnMaster() { - conf = hiveOp.getConf() - + super.initializeOnMaster() + conf = desc + // TODO currently remove the join filter + conf.getFilters().clear() + order = conf.getTagOrder() joinConditions = conf.getConds() numTables = parentOperators.size @@ -91,10 +87,23 @@ abstract class CommonJoinOperator[JOINDESCTYPE <: JoinDesc, T <: HiveCommonJoinO joinValuesStandardObjectInspectors = HiveJoinUtil.getStandardObjectInspectors( joinValuesObjectInspectors, CommonJoinOperator.NOTSKIPBIGTABLE) } + + // copied from the org.apache.hadoop.hive.ql.exec.CommonJoinOperator + override def outputObjectInspector() = { + var structFieldObjectInspectors = new JavaArrayList[ObjectInspector]() + for (alias <- order) { + var oiList = joinValuesStandardObjectInspectors.get(alias) + structFieldObjectInspectors.addAll(oiList) + } + + ObjectInspectorFactory.getStandardStructObjectInspector( + conf.getOutputColumnNames(), + structFieldObjectInspectors) + } } -class CartesianProduct[T >: Null : ClassManifest](val numTables: Int) { +class CartesianProduct[T >: Null : ClassTag](val numTables: Int) { val SINGLE_NULL_LIST = Seq[T](null) val EMPTY_LIST = Seq[T]() @@ -202,6 +211,9 @@ object CommonJoinOperator { */ def isFiltered(row: Any, filters: JavaList[ExprNodeEvaluator], ois: JavaList[ObjectInspector]) : Boolean = { + // if no filter, then will not be filtered + if (filters == null || ois == null) return false + var ret: java.lang.Boolean = false var j = 0 while (j < filters.size) { @@ -209,7 +221,7 @@ object CommonJoinOperator { ret = ois.get(j).asInstanceOf[PrimitiveObjectInspector].getPrimitiveJavaObject( condition).asInstanceOf[java.lang.Boolean] if (ret == null || !ret) { - return true; + return true } j += 1 } diff --git a/src/main/scala/shark/execution/EmptyRDD.scala b/src/main/scala/shark/execution/EmptyRDD.scala deleted file mode 100644 index 534a34bc..00000000 --- a/src/main/scala/shark/execution/EmptyRDD.scala +++ /dev/null @@ -1,18 +0,0 @@ -package shark.execution - -import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} -import org.apache.spark.rdd.RDD - -/** - * An RDD that is empty, i.e. has no element in it. - * - * TODO: Remove this once EmptyRDD is in Spark. 
- */ -class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { - - override def getPartitions: Array[Partition] = Array.empty - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - throw new UnsupportedOperationException("empty RDD") - } -} diff --git a/src/main/scala/shark/execution/ExtractOperator.scala b/src/main/scala/shark/execution/ExtractOperator.scala index 1f64a3e4..767f6573 100755 --- a/src/main/scala/shark/execution/ExtractOperator.scala +++ b/src/main/scala/shark/execution/ExtractOperator.scala @@ -20,8 +20,8 @@ package shark.execution import scala.reflect.BeanProperty import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, ExprNodeEvaluatorFactory} -import org.apache.hadoop.hive.ql.exec.{ExtractOperator => HiveExtractOperator} import org.apache.hadoop.hive.ql.plan.{ExtractDesc, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.io.BytesWritable @@ -29,7 +29,8 @@ import org.apache.hadoop.io.BytesWritable import org.apache.spark.rdd.RDD -class ExtractOperator extends UnaryOperator[HiveExtractOperator] with HiveTopOperator { +class ExtractOperator extends UnaryOperator[ExtractDesc] + with ReduceSinkTableDesc { @BeanProperty var conf: ExtractDesc = _ @BeanProperty var valueTableDesc: TableDesc = _ @@ -39,18 +40,28 @@ class ExtractOperator extends UnaryOperator[HiveExtractOperator] with HiveTopOpe @transient var valueDeser: Deserializer = _ override def initializeOnMaster() { - conf = hiveOp.getConf() + super.initializeOnMaster() + + conf = desc localHconf = super.hconf - valueTableDesc = keyValueTableDescs.values.head._2 + valueTableDesc = keyValueDescs().head._2._2 } override def initializeOnSlave() { + super.initializeOnSlave() + eval = ExprNodeEvaluatorFactory.get(conf.getCol) eval.initialize(objectInspector) valueDeser = valueTableDesc.getDeserializerClass().newInstance() valueDeser.initialize(localHconf, valueTableDesc.getProperties()) } + override def outputObjectInspector() = { + var soi = objectInspectors(0).asInstanceOf[StructObjectInspector] + // take the value part + soi.getAllStructFieldRefs().get(1).getFieldObjectInspector() + } + override def preprocessRdd(rdd: RDD[_]): RDD[_] = { // TODO: hasOrder and limit should really be made by optimizer. 
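// (Hedged note on the initializeOnMaster() change above: keyValueDescs(),
// apparently provided by the ReduceSinkTableDesc trait, returns
// (tag, (keyTableDesc, valueTableDesc)) pairs, so .head._2._2 selects the
// value-side TableDesc of the first upstream ReduceSink, while .head._2._1
// would be its key-side TableDesc; the JoinOperator hunk later in this patch
// uses the same shape.)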
val hasOrder = parentOperator match { diff --git a/src/main/scala/shark/execution/FileSinkOperator.scala b/src/main/scala/shark/execution/FileSinkOperator.scala index cfd93640..d26142fb 100644 --- a/src/main/scala/shark/execution/FileSinkOperator.scala +++ b/src/main/scala/shark/execution/FileSinkOperator.scala @@ -17,6 +17,9 @@ package shark.execution +import java.text.SimpleDateFormat +import java.util.Date + import scala.reflect.BeanProperty import org.apache.hadoop.fs.FileSystem @@ -25,9 +28,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.exec.{FileSinkOperator => HiveFileSinkOperator} import org.apache.hadoop.hive.ql.exec.JobCloseFeedBack import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.mapred.TaskID -import org.apache.hadoop.mapred.TaskAttemptID -import org.apache.hadoop.mapred.SparkHadoopWriter +import org.apache.hadoop.mapred.{JobID, TaskAttemptID, TaskID} import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD @@ -51,9 +52,15 @@ class FileSinkOperator extends TerminalOperator with Serializable { } def setConfParams(conf: HiveConf, context: TaskContext) { + def createJobID(time: Date, id: Int): JobID = { + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + val jobtrackerID = formatter.format(new Date()) + return new JobID(jobtrackerID, id) + } + val jobID = context.stageId - val splitID = context.splitId - val jID = SparkHadoopWriter.createJobID(now, jobID) + val splitID = context.partitionId + val jID = createJobID(now, jobID) val taID = new TaskAttemptID(new TaskID(jID, true, splitID), 0) conf.set("mapred.job.id", jID.toString) conf.set("mapred.tip.id", taID.getTaskID.toString) @@ -73,6 +80,7 @@ class FileSinkOperator extends TerminalOperator with Serializable { iter.foreach { row => numRows += 1 + // Process and writes each row to a temp file. localHiveOp.processOp(row, 0) } @@ -112,7 +120,7 @@ class FileSinkOperator extends TerminalOperator with Serializable { } } - localHiveOp.closeOp(false) + localHiveOp.closeOp(false /* abort */) Iterator(numRows) } @@ -143,37 +151,55 @@ class FileSinkOperator extends TerminalOperator with Serializable { parentOperators.head match { case op: LimitOperator => - // If there is a limit operator, let's only run one partition at a time to avoid - // launching too many tasks. + // If there is a limit operator, let's run two partitions first. Once we finished running + // the first two partitions, we use that to estimate how many more partitions we need to + // run to satisfy the limit. + val limit = op.limit - val numPartitions = rdd.partitions.length - var totalRows = 0 - var nextPartition = 0 - while (totalRows < limit && nextPartition < numPartitions) { - // Run one partition and get back the number of rows processed there. - totalRows += rdd.context.runJob( + val totalParts = rdd.partitions.length + var rowsFetched = 0L + var partsFetched = 0 + while (rowsFetched < limit && partsFetched < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 2 + if (partsFetched > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. 
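+ // (Illustrative arithmetic, not from the patch: with limit = 100 and the
+ // first 2 partitions yielding rowsFetched = 10, the next round tries
+ // (1.5 * 100 * 2 / 10).toInt = 30 partitions, capped at totalParts by the
+ // runJob range below; if the first round returned no rows at all, the
+ // remaining totalParts - 2 partitions are scanned in one go.)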
+ if (rowsFetched == 0) { + numPartsToTry = totalParts - 2 + } else { + numPartsToTry = (1.5 * limit * partsFetched / rowsFetched).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + rowsFetched += rdd.context.runJob( rdd, FileSinkOperator.executeProcessFileSinkPartition(this), - Seq(nextPartition), + partsFetched until math.min(partsFetched + numPartsToTry, totalParts), allowLocal = false).sum - nextPartition += 1 + partsFetched += numPartsToTry } case _ => - val rows = rdd.context.runJob(rdd, FileSinkOperator.executeProcessFileSinkPartition(this)) - logInfo("Total number of rows written: " + rows.sum) + val rows: Array[Long] = rdd.context.runJob( + rdd, FileSinkOperator.executeProcessFileSinkPartition(this)) + logDebug("Total number of rows written: " + rows.sum) } - hiveOp.jobClose(localHconf, true, new JobCloseFeedBack) + localHiveOp.jobClose(localHconf, true /* success */, new JobCloseFeedBack) rdd } } object FileSinkOperator { + // Write each partition's output to HDFS, and return the number of rows written. def executeProcessFileSinkPartition(operator: FileSinkOperator) = { val op = OperatorSerializationWrapper(operator) - def writeFiles(context: TaskContext, iter: Iterator[_]): Int = { + def writeFiles(context: TaskContext, iter: Iterator[_]): Long = { op.logDebug("Started executing mapPartitions for operator: " + op) op.logDebug("Input object inspectors: " + op.objectInspectors) diff --git a/src/main/scala/shark/execution/FilterOperator.scala b/src/main/scala/shark/execution/FilterOperator.scala index 2548bb59..7dda42d0 100755 --- a/src/main/scala/shark/execution/FilterOperator.scala +++ b/src/main/scala/shark/execution/FilterOperator.scala @@ -21,13 +21,12 @@ import scala.collection.Iterator import scala.reflect.BeanProperty import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, ExprNodeEvaluatorFactory} -import org.apache.hadoop.hive.ql.exec.{FilterOperator => HiveFilterOperator} import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.ql.plan.FilterDesc import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector -class FilterOperator extends UnaryOperator[HiveFilterOperator] { +class FilterOperator extends UnaryOperator[FilterDesc] { @transient var conditionEvaluator: ExprNodeEvaluator = _ @transient var conditionInspector: PrimitiveObjectInspector = _ @@ -35,7 +34,9 @@ class FilterOperator extends UnaryOperator[HiveFilterOperator] { @BeanProperty var conf: FilterDesc = _ override def initializeOnMaster() { - conf = hiveOp.getConf() + super.initializeOnMaster() + + conf = desc } override def initializeOnSlave() { @@ -56,4 +57,4 @@ class FilterOperator extends UnaryOperator[HiveFilterOperator] { } } -} \ No newline at end of file +} diff --git a/src/main/scala/shark/execution/ForwardOperator.scala b/src/main/scala/shark/execution/ForwardOperator.scala index a4b488ae..93e1ab4d 100755 --- a/src/main/scala/shark/execution/ForwardOperator.scala +++ b/src/main/scala/shark/execution/ForwardOperator.scala @@ -17,12 +17,11 @@ package shark.execution -import org.apache.hadoop.hive.ql.exec.{ForwardOperator => HiveForwardOperator} - import org.apache.spark.rdd.RDD +import org.apache.hadoop.hive.ql.plan.ForwardDesc -class ForwardOperator extends UnaryOperator[HiveForwardOperator] { +class ForwardOperator extends UnaryOperator[ForwardDesc] { override def execute(): RDD[_] = executeParents().head._2 diff --git a/src/main/scala/shark/execution/GroupByOperator.scala 
b/src/main/scala/shark/execution/GroupByOperator.scala index 8e9fe517..db65b990 100755 --- a/src/main/scala/shark/execution/GroupByOperator.scala +++ b/src/main/scala/shark/execution/GroupByOperator.scala @@ -19,7 +19,6 @@ package shark.execution import org.apache.hadoop.hive.ql.exec.{GroupByOperator => HiveGroupByOperator} import org.apache.hadoop.hive.ql.exec.{ReduceSinkOperator => HiveReduceSinkOperator} -import org.apache.hadoop.hive.ql.plan.GroupByDesc /** diff --git a/src/main/scala/shark/execution/GroupByPostShuffleOperator.scala b/src/main/scala/shark/execution/GroupByPostShuffleOperator.scala index 737f5b95..bc3cfb92 100755 --- a/src/main/scala/shark/execution/GroupByPostShuffleOperator.scala +++ b/src/main/scala/shark/execution/GroupByPostShuffleOperator.scala @@ -44,7 +44,8 @@ import shark.execution.serialization.OperatorSerializationWrapper // The final phase of group by. // TODO(rxin): For multiple distinct aggregations, use sort-based shuffle. -class GroupByPostShuffleOperator extends GroupByPreShuffleOperator with HiveTopOperator { +class GroupByPostShuffleOperator extends GroupByPreShuffleOperator + with ReduceSinkTableDesc { @BeanProperty var keyTableDesc: TableDesc = _ @BeanProperty var valueTableDesc: TableDesc = _ @@ -64,16 +65,9 @@ class GroupByPostShuffleOperator extends GroupByPreShuffleOperator with HiveTopO @transient val distinctHashSets = new JHashMap[Int, JArrayList[JHashSet[KeyWrapper]]]() @transient var unionExprEvaluator: ExprNodeEvaluator = _ - override def initializeOnMaster() { - super.initializeOnMaster() - keyTableDesc = keyValueTableDescs.values.head._1 - valueTableDesc = keyValueTableDescs.values.head._2 - initializeOnSlave() - } - - override def initializeOnSlave() { + override def createLocals() { + super.createLocals() - super.initializeOnSlave() // Initialize unionExpr. KEY has union field as the last field if there are distinct aggrs. unionExprEvaluator = initializeUnionExprEvaluator(rowInspector) @@ -90,6 +84,14 @@ class GroupByPostShuffleOperator extends GroupByPreShuffleOperator with HiveTopO valueSer1 = valueTableDesc.getDeserializerClass.newInstance() valueSer1.initialize(null, valueTableDesc.getProperties()) } + + override def createRemotes() { + super.createRemotes() + + var kvd = keyValueDescs() + keyTableDesc = kvd.head._2._1 + valueTableDesc = kvd.head._2._2 + } private def initializeKeyWrapperFactories() { distinctKeyAggrs.keySet.iterator.foreach { unionId => @@ -219,7 +221,8 @@ class GroupByPostShuffleOperator extends GroupByPreShuffleOperator with HiveTopO // No distinct keys. val aggregator = new Aggregator[Any, Any, ArrayBuffer[Any]]( GroupByAggregator.createCombiner _, GroupByAggregator.mergeValue _, null) - val hashedRdd = repartitionedRDD.mapPartitions(aggregator.combineValuesByKey(_), + val hashedRdd = repartitionedRDD.mapPartitionsWithContext( + (context, iter) => aggregator.combineValuesByKey(iter, context), preservesPartitioning = true) val op = OperatorSerializationWrapper(this) @@ -231,7 +234,7 @@ class GroupByPostShuffleOperator extends GroupByPreShuffleOperator with HiveTopO } def sortAggregate(iter: Iterator[_]) = { - logInfo("Running Post Shuffle Group-By") + logDebug("Running Post Shuffle Group-By") if (iter.hasNext) { // Sort based aggregation iterator. @@ -401,7 +404,7 @@ class GroupByPostShuffleOperator extends GroupByPreShuffleOperator with HiveTopO def hashAggregate(iter: Iterator[_]) = { // TODO: use MutableBytesWritable to avoid the array copy. 
val bytes = new BytesWritable() - logInfo("Running Post Shuffle Group-By") + logDebug("Running Post Shuffle Group-By") val outputCache = new Array[Object](keyFields.length + aggregationEvals.length) // The reusedRow is used to conform to Hive's expected row format. diff --git a/src/main/scala/shark/execution/GroupByPreShuffleOperator.scala b/src/main/scala/shark/execution/GroupByPreShuffleOperator.scala index f19ec7a6..443f858c 100755 --- a/src/main/scala/shark/execution/GroupByPreShuffleOperator.scala +++ b/src/main/scala/shark/execution/GroupByPreShuffleOperator.scala @@ -21,10 +21,10 @@ package org.apache.hadoop.hive.ql.exec import java.util.{ArrayList => JArrayList, HashMap => JHashMap} import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer import scala.reflect.BeanProperty import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.exec.{GroupByOperator => HiveGroupByOperator} import org.apache.hadoop.hive.ql.plan.{AggregationDesc, ExprNodeDesc, GroupByDesc} import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer @@ -38,7 +38,7 @@ import shark.execution.UnaryOperator /** * The pre-shuffle group by operator responsible for map side aggregations. */ -class GroupByPreShuffleOperator extends UnaryOperator[HiveGroupByOperator] { +class GroupByPreShuffleOperator extends UnaryOperator[GroupByDesc] { @BeanProperty var conf: GroupByDesc = _ @BeanProperty var minReductionHashAggr: Float = _ @@ -49,6 +49,7 @@ class GroupByPreShuffleOperator extends UnaryOperator[HiveGroupByOperator] { // The aggregation functions. @transient var aggregationEvals: Array[GenericUDAFEvaluator] = _ + @transient var aggregationObjectInspectors: Array[ObjectInspector] = _ // Key fields to be grouped. 
@transient var keyFields: Array[ExprNodeEvaluator] = _ @@ -60,22 +61,17 @@ class GroupByPreShuffleOperator extends UnaryOperator[HiveGroupByOperator] { @transient var aggregationParameterStandardObjectInspectors: Array[Array[ObjectInspector]] = _ @transient var aggregationIsDistinct: Array[Boolean] = _ + @transient var currentKeyObjectInspectors: Array[ObjectInspector] = _ - override def initializeOnMaster() { - conf = hiveOp.getConf() - minReductionHashAggr = hconf.get(HiveConf.ConfVars.HIVEMAPAGGRHASHMINREDUCTION.varname).toFloat - numRowsCompareHashAggr = hconf.get(HiveConf.ConfVars.HIVEGROUPBYMAPINTERVAL.varname).toInt - } - - override def initializeOnSlave() { + def createLocals() { aggregationEvals = conf.getAggregators.map(_.getGenericUDAFEvaluator).toArray aggregationIsDistinct = conf.getAggregators.map(_.getDistinct).toArray rowInspector = objectInspector.asInstanceOf[StructObjectInspector] keyFields = conf.getKeys().map(k => ExprNodeEvaluatorFactory.get(k)).toArray val keyObjectInspectors: Array[ObjectInspector] = keyFields.map(k => k.initialize(rowInspector)) - val currentKeyObjectInspectors = keyObjectInspectors.map { k => - ObjectInspectorUtils.getStandardObjectInspector(k, ObjectInspectorCopyOption.WRITABLE) - } + currentKeyObjectInspectors = keyObjectInspectors.map { k => + ObjectInspectorUtils.getStandardObjectInspector(k, ObjectInspectorCopyOption.WRITABLE) + } aggregationParameterFields = conf.getAggregators.toArray.map { aggr => aggr.asInstanceOf[AggregationDesc].getParameters.toArray.map { param => @@ -96,18 +92,57 @@ class GroupByPreShuffleOperator extends UnaryOperator[HiveGroupByOperator] { aggregationParameterObjectInspectors(pair._2)) } + aggregationObjectInspectors = + Array.tabulate[ObjectInspector](aggregationEvals.length) { i=> + val mode = conf.getAggregators()(i).getMode() + aggregationEvals(i).init(mode, aggregationParameterObjectInspectors(i)) + } + val keyFieldNames = conf.getOutputColumnNames.slice(0, keyFields.length) val totalFields = keyFields.length + aggregationEvals.length val keyois = new JArrayList[ObjectInspector](totalFields) keyObjectInspectors.foreach(keyois.add(_)) - keyObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(keyFieldNames, keyois) + keyObjectInspector = ObjectInspectorFactory. 
+ getStandardStructObjectInspector(keyFieldNames, keyois) keyFactory = new KeyWrapperFactory(keyFields, keyObjectInspectors, currentKeyObjectInspectors) } + + def createRemotes() { + conf = desc + minReductionHashAggr = hconf.get(HiveConf.ConfVars.HIVEMAPAGGRHASHMINREDUCTION.varname).toFloat + numRowsCompareHashAggr = hconf.get(HiveConf.ConfVars.HIVEGROUPBYMAPINTERVAL.varname).toInt + } + + override def initializeOnMaster() { + super.initializeOnMaster() + + createRemotes() + createLocals() + } + + override def initializeOnSlave() { + super.initializeOnSlave() + createLocals() + } + // copied from the org.apache.hadoop.hive.ql.exec.GroupByOperator + override def outputObjectInspector() = { + val totalFields = keyFields.length + aggregationEvals.length + + val ois = new ArrayBuffer[ObjectInspector](totalFields) + ois ++= (currentKeyObjectInspectors) + ois ++= (aggregationObjectInspectors) + + val fieldNames = conf.getOutputColumnNames() + + import scala.collection.JavaConversions._ + ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, ois.toList) + } + override def processPartition(split: Int, iter: Iterator[_]) = { - logInfo("Running Pre-Shuffle Group-By") + logDebug("Running Pre-Shuffle Group-By") var numRowsInput = 0 var numRowsHashTbl = 0 var useHashAggr = true @@ -148,9 +183,9 @@ class GroupByPreShuffleOperator extends UnaryOperator[HiveGroupByOperator] { } else { logInfo("Mapside hash aggregation enabled") } - logInfo("#hash table="+numRowsHashTbl+" #rows="+ - numRowsInput+" reduction="+numRowsHashTbl.toFloat/numRowsInput+ - " minReduction="+minReductionHashAggr) + logInfo("#hash table=" + numRowsHashTbl + " #rows=" + + numRowsInput + " reduction=" + numRowsHashTbl.toFloat/numRowsInput + + " minReduction=" + minReductionHashAggr) } } diff --git a/src/main/scala/shark/execution/HadoopTableReader.scala b/src/main/scala/shark/execution/HadoopTableReader.scala new file mode 100644 index 00000000..11a0b3e7 --- /dev/null +++ b/src/main/scala/shark/execution/HadoopTableReader.scala @@ -0,0 +1,257 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package shark.execution + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.Constants.META_TABLE_PARTITION_COLUMNS +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.SerializableWritable + +import shark.{SharkEnv, Utils} + + +/** + * Helper class for scanning tables stored in Hadoop - e.g., to read Hive tables that reside in the + * data warehouse directory. + */ +class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf: HiveConf) + extends TableReader { + + // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless + // it is smaller than what Spark suggests. + private val _minSplitsPerRDD = math.max( + _localHConf.getInt("mapred.map.tasks", 1), SharkEnv.sc.defaultMinSplits) + + // Add security credentials before broadcasting the Hive configuration, which is used accross all + // reads done by an instance of this class. + HadoopTableReader.addCredentialsToConf(_localHConf) + private val _broadcastedHiveConf = SharkEnv.sc.broadcast(new SerializableWritable(_localHConf)) + + def broadcastedHiveConf = _broadcastedHiveConf + + def hiveConf = _broadcastedHiveConf.value.value + + override def makeRDDForTable( + hiveTable: HiveTable, + pruningFnOpt: Option[PruningFunctionType] = None + ): RDD[_] = + makeRDDForTable( + hiveTable, + _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + filterOpt = None) + + /** + * Creates a Hadoop RDD to read data from the target table's data directory. Returns a transformed + * RDD that contains deserialized rows. + * + * @param hiveTable Hive metadata for the table being scanned. + * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop. + * @param filterOpt If defined, then the filter is used to reject files contained in the data + * directory being read. If None, then all files are accepted. + */ + def makeRDDForTable( + hiveTable: HiveTable, + deserializerClass: Class[_ <: Deserializer], + filterOpt: Option[PathFilter]): RDD[_] = { + assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, + since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") + + // Create local references to member variables, so that the entire `this` object won't be + // serialized in the closure below. 
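+ // (Editorial, hedged: this is the usual Spark closure idiom. Referring to
+ // _tableDesc or _broadcastedHiveConf directly inside the mapPartitions
+ // closure below would capture `this` and pull the whole HadoopTableReader
+ // instance into the task closure; copying the needed members into local
+ // vals keeps the shipped closure small and serializable.)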
+ val tableDesc = _tableDesc + val broadcastedHiveConf = _broadcastedHiveConf + + val tablePath = hiveTable.getPath + val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt) + + logDebug("Table input: %s".format(tablePath)) + val ifc = hiveTable.getInputFormatClass + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + + val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => + val hconf = broadcastedHiveConf.value.value + val deserializer = deserializerClass.newInstance().asInstanceOf[Deserializer] + deserializer.initialize(hconf, tableDesc.getProperties) + + // Deserialize each Writable to get the row value. + iter.map { value => + value match { + case v: Writable => deserializer.deserialize(v) + case _ => throw new RuntimeException("Failed to match " + value.toString) + } + } + } + deserializedHadoopRDD + } + + override def makeRDDForPartitionedTable( + partitions: Seq[HivePartition], + pruningFnOpt: Option[PruningFunctionType] = None + ): RDD[_] = { + val partitionToDeserializer = partitions.map(part => + (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap + makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) + } + + /** + * Create a HadoopRDD for every partition key specified in the query. Note that for on-disk Hive + * tables, a data directory is created for each partition corresponding to keys specified using + * 'PARTITION BY'. + * + * @param partitionToDeserializer Mapping from a Hive Partition metadata object to the SerDe + * class to use to deserialize input Writables from the corresponding partition. + * @param filterOpt If defined, then the filter is used to reject files contained in the data + * subdirectory of each partition being read. If None, then all files are accepted. + */ + def makeRDDForPartitionedTable( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[_] = { + val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + val partDesc = Utilities.getPartitionDesc(partition) + val partPath = partition.getPartitionPath + val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) + val ifc = partDesc.getInputFileFormatClass + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + // Get partition field info + val partSpec = partDesc.getPartSpec() + val partProps = partDesc.getProperties() + + val partColsDelimited = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) + // Partitioning columns are delimited by "/" + val partCols = partColsDelimited.trim().split("/").toSeq + // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'. + val partValues = if (partSpec == null) { + Array.fill(partCols.size)(new String) + } else { + partCols.map(col => new String(partSpec.get(col))).toArray + } + + // Create local references so that the outer object isn't serialized. 
+ val tableDesc = _tableDesc + val broadcastedHiveConf = _broadcastedHiveConf + val localDeserializer = partDeserializer + + val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + hivePartitionRDD.mapPartitions { iter => + val hconf = broadcastedHiveConf.value.value + val rowWithPartArr = new Array[Object](2) + // Map each tuple to a row object + iter.map { value => + val deserializer = localDeserializer.newInstance() + deserializer.initialize(hconf, partProps) + val deserializedRow = deserializer.deserialize(value) // LazyStruct + rowWithPartArr.update(0, deserializedRow) + rowWithPartArr.update(1, partValues) + rowWithPartArr.asInstanceOf[Object] + } + } + }.toSeq + // Even if we don't use any partitions, we still need an empty RDD + if (hivePartitionRDDs.size == 0) { + new EmptyRDD[Object](SharkEnv.sc) + } else { + new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) + } + } + + /** + * If `filterOpt` is defined, then it will be used to filter files from `path`. These files are + * returned in a single, comma-separated string. + */ + private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { + filterOpt match { + case Some(filter) => { + val fs = path.getFileSystem(_localHConf) + val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) + filteredFiles.mkString(",") + } + case None => path.toString + } + } + + /** + * Creates a HadoopRDD based on the broadcasted HiveConf and other job properties that will be + * applied locally on each slave. + */ + private def createHadoopRdd( + tableDesc: TableDesc, + path: String, + inputFormatClass: Class[InputFormat[Writable, Writable]]) + : RDD[Writable] = { + val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ + + val rdd = new HadoopRDD( + SharkEnv.sc, + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + Some(initializeJobConfFunc), + inputFormatClass, + classOf[Writable], + classOf[Writable], + _minSplitsPerRDD) + + // Only take the value (skip the key) because Hive works only with values. + rdd.map(_._2) + } + +} + +object HadoopTableReader { + + /** + * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to + * instantiate a HadoopRDD. + */ + def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { + FileInputFormat.setInputPaths(jobConf, path) + if (tableDesc != null) { + Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + jobConf.set("io.file.buffer.size", bufferSize) + } + + /** Adds S3 credentials to the `conf`. */ + def addCredentialsToConf(conf: Configuration) { + // Set s3/s3n credentials. Setting them in localJobConf ensures the settings propagate + // from Spark's master all the way to Spark's slaves. + var s3varsSet = false + Seq("fs.s3n.awsAccessKeyId", "fs.s3n.awsSecretAccessKey", + "fs.s3.awsAccessKeyId", "fs.s3.awsSecretAccessKey").foreach { variableName => + if (conf.get(variableName) != null) { + s3varsSet = true + } + } + + // If none of the s3 credentials are set in Hive conf, try use the environmental + // variables for credentials. 
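+ // (Hedged note: when none of the fs.s3*/fs.s3n* keys above are present in
+ // the Hive conf, Utils.setAwsCredentials presumably falls back to the
+ // standard AWS environment variables, e.g. AWS_ACCESS_KEY_ID and
+ // AWS_SECRET_ACCESS_KEY; the exact helper behaviour is project-specific.)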
+ if (!s3varsSet) { + Utils.setAwsCredentials(conf) + } + } +} diff --git a/src/main/scala/shark/execution/HiveTopOperator.scala b/src/main/scala/shark/execution/HiveTopOperator.scala deleted file mode 100755 index 825c9f90..00000000 --- a/src/main/scala/shark/execution/HiveTopOperator.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (C) 2012 The Regents of The University California. - * All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package shark.execution - -import org.apache.hadoop.hive.ql.metadata.HiveException -import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector - -import shark.LogHelper - - -/** - * Operators that are top operators in Hive stages. This includes TableScan and - * everything that can come after ReduceSink. Note that they might have multiple - * upstream operators (multiple parents). - */ -trait HiveTopOperator extends LogHelper { - self: Operator[_ <: HiveOperator] => - - /** - * Stores the input object inspectors. This is passed down by either the - * upstream operators (i.e. ReduceSink) or in the case of TableScan, passed - * by the init code in SparkTask. - */ - @transient - val inputObjectInspectors = new scala.collection.mutable.HashMap[Int, ObjectInspector] - - /** - * Stores the deser for operators downstream from ReduceSink. This is set by - * ReduceSink.initializeDownStreamHiveOperators(). - */ - @transient - val keyValueTableDescs = new scala.collection.mutable.HashMap[Int, (TableDesc, TableDesc)] - - /** - * Initialize the Hive operator when all input object inspectors are ready. - */ - def initializeHiveTopOperator() { - logInfo("Started executing " + self + " initializeHiveTopOperator()") - - // Call initializeDownStreamHiveOperators() of upstream operators that are - // ReduceSink so we can get the proper input object inspectors and serdes. - val reduceSinkParents = self.parentOperators.filter(_.isInstanceOf[ReduceSinkOperator]) - reduceSinkParents.foreach { parent => - parent.asInstanceOf[ReduceSinkOperator].initializeDownStreamHiveOperator() - logInfo("parent : " + parent) - } - - // Only do initialize if all our input inspectors are ready. We use > - // instead of == since TableScan doesn't have parents, but have an object - // inspector. If == is used, table scan is skipped. - assert(inputObjectInspectors.size >= reduceSinkParents.size, - println("# input object inspectors (%d) < # reduce sink parent operators (%d)".format( - inputObjectInspectors.size, reduceSinkParents.size))) - - val objectInspectorArray = { - // Special case for single object inspector (non join case) because the - // joinTag is -1. - if (inputObjectInspectors.size == 1) { - Array(inputObjectInspectors.values.head) - } else { - val arr = new Array[ObjectInspector](inputObjectInspectors.size) - inputObjectInspectors foreach { case (tag, inspector) => arr(tag) = inspector } - arr - } - } - - if (objectInspectorArray.size > 0) { - // Initialize the hive operators. 
This init propagates downstream. - logDebug("Executing " + self.hiveOp + ".initialize()") - self.hiveOp.initialize(hconf, objectInspectorArray) - } - - logInfo("Finished executing " + self + " initializeHiveTopOperator()") - } - - def setInputObjectInspector(tag: Int, objectInspector: ObjectInspector) { - inputObjectInspectors.put(tag, objectInspector) - } - - def setKeyValueTableDescs(tag: Int, descs: (TableDesc, TableDesc)) { - keyValueTableDescs.put(tag, descs) - } - -} - diff --git a/src/main/scala/shark/execution/JoinOperator.scala b/src/main/scala/shark/execution/JoinOperator.scala index a641a264..21592b67 100755 --- a/src/main/scala/shark/execution/JoinOperator.scala +++ b/src/main/scala/shark/execution/JoinOperator.scala @@ -24,10 +24,10 @@ import scala.collection.JavaConversions._ import scala.reflect.BeanProperty import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.exec.{JoinOperator => HiveJoinOperator} import org.apache.hadoop.hive.ql.plan.{JoinDesc, TableDesc} -import org.apache.hadoop.hive.serde2.{Deserializer, Serializer, SerDeUtils} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StandardStructObjectInspector} +import org.apache.hadoop.hive.serde2.{Deserializer, SerDeUtils} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils +import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector import org.apache.hadoop.io.BytesWritable import org.apache.spark.{CoGroupedRDD, HashPartitioner} @@ -36,8 +36,7 @@ import org.apache.spark.rdd.RDD import shark.execution.serialization.OperatorSerializationWrapper -class JoinOperator extends CommonJoinOperator[JoinDesc, HiveJoinOperator] - with HiveTopOperator { +class JoinOperator extends CommonJoinOperator[JoinDesc] with ReduceSinkTableDesc { @BeanProperty var valueTableDescMap: JHashMap[Int, TableDesc] = _ @BeanProperty var keyTableDesc: TableDesc = _ @@ -48,9 +47,10 @@ class JoinOperator extends CommonJoinOperator[JoinDesc, HiveJoinOperator] override def initializeOnMaster() { super.initializeOnMaster() + val descs = keyValueDescs() valueTableDescMap = new JHashMap[Int, TableDesc] - valueTableDescMap ++= keyValueTableDescs.map { case(tag, kvdescs) => (tag, kvdescs._2) } - keyTableDesc = keyValueTableDescs.head._2._1 + valueTableDescMap ++= descs.map { case(tag, kvdescs) => (tag, kvdescs._2) } + keyTableDesc = descs.head._2._1 // Call initializeOnSlave to initialize the join filters, etc. initializeOnSlave() @@ -113,7 +113,7 @@ class JoinOperator extends CommonJoinOperator[JoinDesc, HiveJoinOperator] op.initializeOnSlave() val writable = new BytesWritable - val nullSafes = op.conf.getNullSafes() + val nullSafes = conf.getNullSafes() val cp = new CartesianProduct[Any](op.numTables) diff --git a/src/main/scala/shark/execution/JoinUtil.scala b/src/main/scala/shark/execution/JoinUtil.scala index 2fb22731..512ad494 100644 --- a/src/main/scala/shark/execution/JoinUtil.scala +++ b/src/main/scala/shark/execution/JoinUtil.scala @@ -56,13 +56,17 @@ object JoinUtil { noOuterJoin: Boolean): Array[AnyRef] = { val isFiltered: Boolean = { - Range(0, filters.size()).exists { x => - val cond = filters.get(x).evaluate(row) - val result = Option[AnyRef]( - filtersOI.get(x).asInstanceOf[PrimitiveOI].getPrimitiveJavaObject(cond)) - result match { - case Some(u) => u.asInstanceOf[Boolean].unary_! 
- case None => true + if (filters == null) { + false + } else { + Range(0, filters.size()).exists { x => + val cond = filters.get(x).evaluate(row) + val result = Option[AnyRef]( + filtersOI.get(x).asInstanceOf[PrimitiveOI].getPrimitiveJavaObject(cond)) + result match { + case Some(u) => u.asInstanceOf[Boolean].unary_! + case None => true + } } } } @@ -78,7 +82,7 @@ object JoinUtil { if (noOuterJoin) { a } else { - val n = new Array[AnyRef](size+1) + val n = new Array[AnyRef](size + 1) Array.copy(a, 0, n, 0, size) n(size) = new SerializableWritable(new BooleanWritable(isFiltered)) n diff --git a/src/main/scala/shark/execution/LateralViewForwardOperator.scala b/src/main/scala/shark/execution/LateralViewForwardOperator.scala index 65efd679..458bd7c3 100755 --- a/src/main/scala/shark/execution/LateralViewForwardOperator.scala +++ b/src/main/scala/shark/execution/LateralViewForwardOperator.scala @@ -17,12 +17,12 @@ package shark.execution -import org.apache.hadoop.hive.ql.exec.{LateralViewForwardOperator => HiveLateralViewForwardOperator} +import org.apache.hadoop.hive.ql.plan.LateralViewForwardDesc import org.apache.spark.rdd.RDD -class LateralViewForwardOperator extends UnaryOperator[HiveLateralViewForwardOperator] { +class LateralViewForwardOperator extends UnaryOperator[LateralViewForwardDesc] { override def execute(): RDD[_] = executeParents().head._2 diff --git a/src/main/scala/shark/execution/LateralViewJoinOperator.scala b/src/main/scala/shark/execution/LateralViewJoinOperator.scala index 55164348..603e3b87 100755 --- a/src/main/scala/shark/execution/LateralViewJoinOperator.scala +++ b/src/main/scala/shark/execution/LateralViewJoinOperator.scala @@ -18,17 +18,18 @@ package shark.execution import java.nio.ByteBuffer -import java.util.ArrayList +import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.reflect.BeanProperty import org.apache.commons.codec.binary.Base64 import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, ExprNodeEvaluatorFactory} -import org.apache.hadoop.hive.ql.exec.{LateralViewJoinOperator => HiveLateralViewJoinOperator} -import org.apache.hadoop.hive.ql.plan.SelectDesc +import org.apache.hadoop.hive.ql.plan.{LateralViewJoinDesc, SelectDesc} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector} +import org.apache.spark.SparkEnv import org.apache.spark.rdd.RDD import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer} @@ -39,7 +40,7 @@ import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer} * Hive handles this by having two branches in its plan, then joining their output (see diagram in * LateralViewJoinOperator.java). We put all the explode logic here instead. 
*/ -class LateralViewJoinOperator extends NaryOperator[HiveLateralViewJoinOperator] { +class LateralViewJoinOperator extends NaryOperator[LateralViewJoinDesc] { @BeanProperty var conf: SelectDesc = _ @BeanProperty var lvfOp: LateralViewForwardOperator = _ @@ -51,9 +52,10 @@ class LateralViewJoinOperator extends NaryOperator[HiveLateralViewJoinOperator] @transient var fieldOis: StructObjectInspector = _ override def initializeOnMaster() { + super.initializeOnMaster() // Get conf from Select operator beyond UDTF Op to get eval() conf = parentOperators.filter(_.isInstanceOf[UDTFOperator]).head - .parentOperators.head.asInstanceOf[SelectOperator].hiveOp.getConf() + .parentOperators.head.asInstanceOf[SelectOperator].desc udtfOp = parentOperators.filter(_.isInstanceOf[UDTFOperator]).head.asInstanceOf[UDTFOperator] udtfOIString = KryoSerializerToString.serialize(udtfOp.objectInspectors) @@ -76,6 +78,29 @@ class LateralViewJoinOperator extends NaryOperator[HiveLateralViewJoinOperator] udtfOp.initializeOnSlave() } + override def outputObjectInspector() = { + val SELECT_TAG = 0 + val UDTF_TAG = 1 + + val ois = new ArrayBuffer[ObjectInspector]() + val fieldNames = desc.getOutputInternalColNames() + + // The output of the lateral view join will be the columns from the select + // parent, followed by the column from the UDTF parent + var soi = objectInspectors(SELECT_TAG).asInstanceOf[StructObjectInspector] + + for (sf <- soi.getAllStructFieldRefs()) { + ois.add(sf.getFieldObjectInspector()); + } + + soi = objectInspectors(UDTF_TAG).asInstanceOf[StructObjectInspector] + for (sf <- soi.getAllStructFieldRefs()) { + ois.add(sf.getFieldObjectInspector()); + } + + ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, ois) + } + override def execute: RDD[_] = { // Execute LateralViewForwardOperator, bypassing Select / UDTF - Select // branches (see diagram in Hive's). @@ -127,7 +152,7 @@ class LateralViewJoinOperator extends NaryOperator[HiveLateralViewJoinOperator] */ object KryoSerializerToString { - @transient val kryoSer = new SparkKryoSerializer + @transient val kryoSer = new SparkKryoSerializer(SparkEnv.get.conf) def serialize[T](o: T): String = { val bytes = kryoSer.newInstance().serialize(o).array() diff --git a/src/main/scala/shark/execution/LimitOperator.scala b/src/main/scala/shark/execution/LimitOperator.scala index a66b0612..c78c0ab4 100755 --- a/src/main/scala/shark/execution/LimitOperator.scala +++ b/src/main/scala/shark/execution/LimitOperator.scala @@ -17,23 +17,21 @@ package shark.execution -import scala.collection.Iterator -import scala.reflect.BeanProperty +import org.apache.hadoop.hive.ql.plan.LimitDesc -import org.apache.hadoop.hive.ql.exec.{LimitOperator => HiveLimitOperator} - -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{EmptyRDD, RDD} import shark.SharkEnv -class LimitOperator extends UnaryOperator[HiveLimitOperator] { + +class LimitOperator extends UnaryOperator[LimitDesc] { // Only works on the master program. - def limit = hiveOp.getConf().getLimit() + def limit = desc.getLimit() override def execute(): RDD[_] = { - val limitNum = hiveOp.getConf().getLimit() + val limitNum = desc.getLimit() if (limitNum > 0) { // Take limit on each partition. 
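(The LimitOperator hunk above ends at the "take limit on each partition" comment; the rest of execute() is outside this hunk, and the new org.apache.spark.rdd.EmptyRDD import suggests a non-positive limit presumably yields an empty RDD. A hypothetical sketch of the per-partition step, assuming Spark 0.9-era RDD APIs; the helper name is illustrative, not the project's exact code.)

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    object LimitSketch {
      // Bound each partition to at most limitNum rows; a global limit would
      // still need a final driver-side take over the surviving rows.
      def perPartitionLimit[T: ClassTag](rdd: RDD[T], limitNum: Int): RDD[T] =
        rdd.mapPartitions(_.take(limitNum), preservesPartitioning = true)
    }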
diff --git a/src/main/scala/shark/execution/MapJoinOperator.scala b/src/main/scala/shark/execution/MapJoinOperator.scala index 32cdbcfe..7ea9bea6 100755 --- a/src/main/scala/shark/execution/MapJoinOperator.scala +++ b/src/main/scala/shark/execution/MapJoinOperator.scala @@ -17,20 +17,18 @@ package shark.execution -import java.util.{HashMap => JHashMap, List => JList} +import java.util.{ArrayList, HashMap => JHashMap, List => JList} -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.reflect.BeanProperty import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, JoinUtil => HiveJoinUtil} -import org.apache.hadoop.hive.ql.exec.{MapJoinOperator => HiveMapJoinOperator} import org.apache.hadoop.hive.ql.plan.MapJoinDesc import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory -import org.apache.spark.SparkEnv import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel import shark.SharkEnv import shark.execution.serialization.{OperatorSerializationWrapper, SerializableWritable} @@ -44,7 +42,7 @@ import shark.execution.serialization.{OperatorSerializationWrapper, Serializable * Different from Hive, we don't spill the hash tables to disk. If the "small" * tables are too big to fit in memory, the normal join should be used anyway. */ -class MapJoinOperator extends CommonJoinOperator[MapJoinDesc, HiveMapJoinOperator] { +class MapJoinOperator extends CommonJoinOperator[MapJoinDesc] { @BeanProperty var posBigTable: Int = _ @BeanProperty var bigTableAlias: Int = _ @@ -81,6 +79,30 @@ class MapJoinOperator extends CommonJoinOperator[MapJoinDesc, HiveMapJoinOperato joinKeys, objectInspectors.toArray, CommonJoinOperator.NOTSKIPBIGTABLE) } + + // copied from the org.apache.hadoop.hive.ql.exec.AbstractMapJoinOperator + override def outputObjectInspector() = { + var outputObjInspector = super.outputObjectInspector() + val structFields = outputObjInspector.asInstanceOf[StructObjectInspector] + .getAllStructFieldRefs() + if (conf.getOutputColumnNames().size() < structFields.size()) { + var structFieldObjectInspectors = new ArrayList[ObjectInspector]() + for (alias <- order) { + var sz = conf.getExprs().get(alias).size() + var retained = conf.getRetainList().get(alias) + for (i <- 0 to sz - 1) { + var pos = retained.get(i) + structFieldObjectInspectors.add(structFields.get(pos).getFieldObjectInspector()) + } + } + outputObjInspector = ObjectInspectorFactory + .getStandardStructObjectInspector( + conf.getOutputColumnNames(), + structFieldObjectInspectors) + } + + outputObjInspector + } override def execute(): RDD[_] = { val inputRdds = executeParents() @@ -92,8 +114,8 @@ class MapJoinOperator extends CommonJoinOperator[MapJoinDesc, HiveMapJoinOperato } override def combineMultipleRdds(rdds: Seq[(Int, RDD[_])]): RDD[_] = { - logInfo("%d small tables to map join a large table (%d)".format(rdds.size - 1, posBigTable)) - logInfo("Big table alias " + bigTableAlias) + logDebug("%d small tables to map join a large table (%d)".format(rdds.size - 1, posBigTable)) + logDebug("Big table alias " + bigTableAlias) val op1 = OperatorSerializationWrapper(this) @@ -102,7 +124,7 @@ class MapJoinOperator extends CommonJoinOperator[MapJoinDesc, HiveMapJoinOperato // Build hash tables for the small tables. 
val hashtables = rdds.zipWithIndex.filter(_._2 != bigTableAlias).map { case ((_, rdd), pos) => - logInfo("Creating hash table for input %d".format(pos)) + logDebug("Creating hash table for input %d".format(pos)) // First compute the keys and values of the small RDDs on slaves. // We need to do this before collecting the RDD because the RDD might @@ -114,6 +136,7 @@ class MapJoinOperator extends CommonJoinOperator[MapJoinDesc, HiveMapJoinOperato // following mapParititons will fail because it tries to include the // outer closure, which references "this". val op = op1 + // An RDD of (Join key, Corresponding rows) tuples. val rddForHash: RDD[(Seq[AnyRef], Seq[Array[AnyRef]])] = rdd.mapPartitions { partition => op.initializeOnSlave() @@ -125,28 +148,14 @@ class MapJoinOperator extends CommonJoinOperator[MapJoinDesc, HiveMapJoinOperato // Collect the RDD and build a hash table. val startCollect = System.currentTimeMillis() - val storageLevel = rddForHash.getStorageLevel - if(storageLevel == StorageLevel.NONE) - rddForHash.persist(StorageLevel.MEMORY_AND_DISK) - rddForHash.foreach(_ => Unit) - val wrappedRows = rddForHash.partitions.flatMap { part => - val blockId = "rdd_%s_%s".format(rddForHash.id, part.index) - val iter = SparkEnv.get.blockManager.get(blockId) - val partRows = new ArrayBuffer[(Seq[AnyRef], Seq[Array[AnyRef]])] - iter.foreach(_.foreach { row => - partRows += row.asInstanceOf[(Seq[AnyRef], Seq[Array[AnyRef]])] - }) - partRows - } - if(storageLevel == StorageLevel.NONE) - rddForHash.unpersist() + val collectedRows: Array[(Seq[AnyRef], Seq[Array[AnyRef]])] = rddForHash.collect() - logInfo("wrappedRows size:" + wrappedRows.size) + logDebug("collectedRows size:" + collectedRows.size) val collectTime = System.currentTimeMillis() - startCollect logInfo("HashTable collect took " + collectTime + " ms") // Build the hash table. 
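// A minimal, illustrative sketch (not part of this patch) of the hash-table build that
// follows: the small table's (join key, rows) pairs are collected to the driver and
// grouped into a java.util.HashMap keyed on the join key. The tuple types mirror
// `rddForHash` above; `buildHashTable` is an invented name for illustration only.
import java.util.{HashMap => JHashMap}

def buildHashTable(
    collectedRows: Array[(Seq[AnyRef], Seq[Array[AnyRef]])])
  : JHashMap[Seq[AnyRef], Array[Array[AnyRef]]] = {
  val map = new JHashMap[Seq[AnyRef], Array[Array[AnyRef]]]()
  collectedRows.groupBy(_._1).foreach { case (key, group) =>
    // Flatten the per-partition row groups for this key into a single array of rows.
    map.put(key, group.flatMap(_._2))
  }
  map
}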
- val hash = wrappedRows.groupBy(x => x._1) + val hash = collectedRows.groupBy(x => x._1) .mapValues(v => v.flatMap(t => t._2)) val map = new JHashMap[Seq[AnyRef], Array[Array[AnyRef]]]() diff --git a/src/main/scala/shark/execution/MapSplitPruning.scala b/src/main/scala/shark/execution/MapSplitPruning.scala index 5959660b..976ebf50 100644 --- a/src/main/scala/shark/execution/MapSplitPruning.scala +++ b/src/main/scala/shark/execution/MapSplitPruning.scala @@ -17,20 +17,18 @@ package org.apache.hadoop.hive.ql.exec -import java.sql.Timestamp - -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual +import org.apache.hadoop.hive.serde2.objectinspector.{MapSplitPruningHelper, StructField} +import org.apache.hadoop.hive.serde2.objectinspector.UnionStructObjectInspector.MyField import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrGreaterThan import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector -import org.apache.hadoop.io.Text +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr import shark.memstore2.ColumnarStructObjectInspector.IDStructField import shark.memstore2.TablePartitionStats @@ -57,13 +55,23 @@ object MapSplitPruning { e.genericUDF match { case _: GenericUDFOPAnd => test(s, e.children(0)) && test(s, e.children(1)) case _: GenericUDFOPOr => test(s, e.children(0)) || test(s, e.children(1)) - case _: GenericUDFBetween => - testBetweenPredicate(s, e.children(0).asInstanceOf[ExprNodeConstantEvaluator], - e.children(1).asInstanceOf[ExprNodeColumnEvaluator], - e.children(2).asInstanceOf[ExprNodeConstantEvaluator], - e.children(3).asInstanceOf[ExprNodeConstantEvaluator]) + case _: GenericUDFBetween => + val col = e.children(1) + if (col.isInstanceOf[ExprNodeColumnEvaluator]) { + testBetweenPredicate(s, e.children(0).asInstanceOf[ExprNodeConstantEvaluator], + col.asInstanceOf[ExprNodeColumnEvaluator], + e.children(2).asInstanceOf[ExprNodeConstantEvaluator], + e.children(3).asInstanceOf[ExprNodeConstantEvaluator]) + } else { + //cannot prune function based evaluators in general. 
+ true + } + case _: GenericUDFIn => - testInPredicate(s, e.children(0).asInstanceOf[ExprNodeColumnEvaluator], e.children.drop(1)) + testInPredicate( + s, + e.children(0).asInstanceOf[ExprNodeColumnEvaluator], + e.children.drop(1)) case udf: GenericUDFBaseCompare => testComparisonPredicate(s, udf, e.children(0), e.children(1)) case _ => true @@ -79,7 +87,7 @@ object MapSplitPruning { columnEval: ExprNodeColumnEvaluator, expEvals: Array[ExprNodeEvaluator]): Boolean = { - val field = columnEval.field.asInstanceOf[IDStructField] + val field = getIDStructField(columnEval.field) val columnStats = s.stats(field.fieldID) if (columnStats != null) { @@ -93,28 +101,29 @@ object MapSplitPruning { true } } - + def testBetweenPredicate( s: TablePartitionStats, invertEval: ExprNodeConstantEvaluator, columnEval: ExprNodeColumnEvaluator, leftEval: ExprNodeConstantEvaluator, rightEval: ExprNodeConstantEvaluator): Boolean = { - - val field = columnEval.field.asInstanceOf[IDStructField] + + val field = getIDStructField(columnEval.field) val columnStats = s.stats(field.fieldID) val leftValue: Object = leftEval.expr.getValue val rightValue: Object = rightEval.expr.getValue val invertValue: Boolean = invertEval.expr.getValue.asInstanceOf[Boolean] - + if (columnStats != null) { - val exists = (columnStats :>< (leftValue , rightValue)) - if (invertValue) !exists else exists - } else { - // If there is no stats on the column, don't prune. - true - } + val exists = (columnStats :>< (leftValue , rightValue)) + if (invertValue) !exists else exists + } else { + // If there is no stats on the column, don't prune. + true + } } + /** * Test whether we should keep the split as a candidate given the comparison * predicate. Return true if the split should be kept as a candidate, false if @@ -128,24 +137,28 @@ object MapSplitPruning { // Try to get the column evaluator. val columnEval: ExprNodeColumnEvaluator = - if (left.isInstanceOf[ExprNodeColumnEvaluator]) + if (left.isInstanceOf[ExprNodeColumnEvaluator]) { left.asInstanceOf[ExprNodeColumnEvaluator] - else if (right.isInstanceOf[ExprNodeColumnEvaluator]) + } else if (right.isInstanceOf[ExprNodeColumnEvaluator]) { right.asInstanceOf[ExprNodeColumnEvaluator] - else null + } else { + null + } // Try to get the constant value. val constEval: ExprNodeConstantEvaluator = - if (left.isInstanceOf[ExprNodeConstantEvaluator]) + if (left.isInstanceOf[ExprNodeConstantEvaluator]) { left.asInstanceOf[ExprNodeConstantEvaluator] - else if (right.isInstanceOf[ExprNodeConstantEvaluator]) + } else if (right.isInstanceOf[ExprNodeConstantEvaluator]) { right.asInstanceOf[ExprNodeConstantEvaluator] - else null + } else { + null + } if (columnEval != null && constEval != null) { // We can prune the partition only if it is a predicate of form // column op const, where op is <, >, =, <=, >=, !=. - val field = columnEval.field.asInstanceOf[IDStructField] + val field = getIDStructField(columnEval.field) val value: Object = constEval.expr.getValue val columnStats = s.stats(field.fieldID) @@ -167,4 +180,17 @@ object MapSplitPruning { true } } + + private def getIDStructField(field: StructField): IDStructField = field match { + case myField: MyField => { + // For partitioned tables, the ColumnarStruct's IDStructFields are enclosed inside + // the Hive UnionStructObjectInspector's MyField objects. 
+ MapSplitPruningHelper.getStructFieldFromUnionOIField(myField) + .asInstanceOf[IDStructField] + } + case idStructField: IDStructField => idStructField + case otherFieldType: Any => { + throw new Exception("Unrecognized StructField: " + otherFieldType) + } + } } diff --git a/src/main/scala/shark/execution/MapSplitPruningHelper.scala b/src/main/scala/shark/execution/MapSplitPruningHelper.scala new file mode 100644 index 00000000..35a1041b --- /dev/null +++ b/src/main/scala/shark/execution/MapSplitPruningHelper.scala @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.serde2.objectinspector + +import org.apache.hadoop.hive.serde2.objectinspector.UnionStructObjectInspector.MyField + + +object MapSplitPruningHelper { + + /** + * Extract the UnionStructObjectInspector.MyField's `structField` reference, which is + * package-private. + */ + def getStructFieldFromUnionOIField(unionOIMyField: MyField): StructField = { + unionOIMyField.structField + } + +} diff --git a/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala b/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala index fc989885..856e101c 100644 --- a/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala +++ b/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala @@ -24,7 +24,7 @@ import scala.reflect.BeanProperty import org.apache.hadoop.io.Writable -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.storage.StorageLevel import shark.{SharkConfVars, SharkEnv} @@ -38,12 +38,36 @@ import shark.tachyon.TachyonTableWriter */ class MemoryStoreSinkOperator extends TerminalOperator { + // The initial capacity for ArrayLists used to construct the columnar storage. If -1, + // the ColumnarSerde will obtain the partition size from a Configuration during execution + // initialization (see ColumnarSerde#initialize()). @BeanProperty var partitionSize: Int = _ + + // If true, columnar storage will use compression. @BeanProperty var shouldCompress: Boolean = _ - @BeanProperty var storageLevel: StorageLevel = _ + + // For CTAS, this is the name of the table that is created. For INSERTS, this is the name of* + // the table that is modified. @BeanProperty var tableName: String = _ - @transient var useTachyon: Boolean = _ - @transient var useUnionRDD: Boolean = _ + + // The Hive metastore DB that the `tableName` table belongs to. + @BeanProperty var databaseName: String = _ + + // Used only for commands that target Hive partitions. The partition key is a set of unique values + // for the the table's partitioning columns and identifies the partition (represented by an RDD) + // that will be created or modified by the INSERT command being handled. + @BeanProperty var hivePartitionKeyOpt: Option[String] = _ + + // The memory storage used to store the output RDD - e.g., CacheType.HEAP refers to Spark's + // block manager. 
+ @transient var cacheMode: CacheType.CacheType = _ + + // Whether to compose a UnionRDD from the output RDD and a previous RDD. For example, for an + // INSERT INTO command, the previous RDD will contain the contents of the 'tableName'. + @transient var isInsertInto: Boolean = _ + + // The number of columns in the schema for the table corresponding to 'tableName'. Used only + // to create a TachyonTableWriter, if Tachyon is used. @transient var numColumns: Int = _ override def initializeOnMaster() { @@ -63,18 +87,23 @@ class MemoryStoreSinkOperator extends TerminalOperator { val statsAcc = SharkEnv.sc.accumulableCollection(ArrayBuffer[(Int, TablePartitionStats)]()) val op = OperatorSerializationWrapper(this) + val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName) val tachyonWriter: TachyonTableWriter = - if (useTachyon) { + if (cacheMode == CacheType.TACHYON) { + if (!isInsertInto && SharkEnv.tachyonUtil.tableExists(tableKey, hivePartitionKeyOpt)) { + // For INSERT OVERWRITE, delete the old table or Hive partition directory, if it exists. + SharkEnv.tachyonUtil.dropTable(tableKey, hivePartitionKeyOpt) + } // Use an additional row to store metadata (e.g. number of rows in each partition). - SharkEnv.tachyonUtil.createTableWriter(tableName, numColumns + 1) + SharkEnv.tachyonUtil.createTableWriter(tableKey, hivePartitionKeyOpt, numColumns + 1) } else { null } // Put all rows of the table into a set of TablePartition's. Each partition contains // only one TablePartition object. - var rdd: RDD[TablePartition] = inputRdd.mapPartitionsWithIndex { case(partitionIndex, iter) => + var outputRDD: RDD[TablePartition] = inputRdd.mapPartitionsWithIndex { case (part, iter) => op.initializeOnSlave() val serde = new ColumnarSerDe serde.initialize(op.localHconf, op.localHiveOp.getConf.getTableInfo.getProperties) @@ -86,103 +115,95 @@ class MemoryStoreSinkOperator extends TerminalOperator { builder = serde.serialize(row.asInstanceOf[AnyRef], op.objectInspector) } - if (builder != null) { - statsAcc += Tuple2(partitionIndex, builder.asInstanceOf[TablePartitionBuilder].stats) - Iterator(builder.asInstanceOf[TablePartitionBuilder].build) - } else { + if (builder == null) { // Empty partition. - statsAcc += Tuple2(partitionIndex, new TablePartitionStats(Array(), 0)) + statsAcc += Tuple2(part, new TablePartitionStats(Array(), 0)) Iterator(new TablePartition(0, Array())) + } else { + statsAcc += Tuple2(part, builder.asInstanceOf[TablePartitionBuilder].stats) + Iterator(builder.asInstanceOf[TablePartitionBuilder].build) } } if (tachyonWriter != null) { // Put the table in Tachyon. - op.logInfo("Putting RDD for %s in Tachyon".format(tableName)) - - SharkEnv.memoryMetadataManager.put(tableName, rdd) - + op.logInfo("Putting RDD for %s.%s in Tachyon".format(databaseName, tableName)) tachyonWriter.createTable(ByteBuffer.allocate(0)) - rdd = rdd.mapPartitionsWithIndex { case(partitionIndex, iter) => + outputRDD = outputRDD.mapPartitionsWithIndex { case(part, iter) => val partition = iter.next() partition.toTachyon.zipWithIndex.foreach { case(buf, column) => - tachyonWriter.writeColumnPartition(column, partitionIndex, buf) + tachyonWriter.writeColumnPartition(column, part, buf) } Iterator(partition) } // Force evaluate so the data gets put into Tachyon. - rdd.context.runJob(rdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) + outputRDD.context.runJob( + outputRDD, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) } else { - // Put the table in Spark block manager. 
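// Illustrative sketch (not from this patch): map the chosen cache mode to a Spark
// StorageLevel and force materialization with a no-op job, so that accumulators such as
// the per-partition statistics above are populated. The helper name is hypothetical.
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

def persistAndMaterialize[T](rdd: RDD[T], storageLevelOpt: Option[StorageLevel]): RDD[T] = {
  storageLevelOpt.foreach(level => rdd.persist(level))
  // Running a job that discards its output is enough to compute (and cache) every partition.
  rdd.context.runJob(rdd, (iter: Iterator[T]) => iter.foreach(_ => ()))
  rdd
}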
- op.logInfo("Putting %sRDD for %s in Spark block manager, %s %s %s %s".format( - if (useUnionRDD) "Union" else "", - tableName, - if (storageLevel.deserialized) "deserialized" else "serialized", - if (storageLevel.useMemory) "in memory" else "", - if (storageLevel.useMemory && storageLevel.useDisk) "and" else "", - if (storageLevel.useDisk) "on disk" else "")) - - // Force evaluate so the data gets put into Spark block manager. - rdd.persist(storageLevel) - rdd.context.runJob(rdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) - - val origRdd = rdd - if (useUnionRDD) { - // If this is an insert, find the existing RDD and create a union of the two, and then - // put the union into the meta data tracker. - rdd = rdd.union( - SharkEnv.memoryMetadataManager.get(tableName).get.asInstanceOf[RDD[TablePartition]]) + // Run a job on the RDD that contains the query output to force the data into the memory + // store. The statistics will also be collected by 'statsAcc' during job execution. + if (cacheMode == CacheType.MEMORY) { + outputRDD.persist(StorageLevel.MEMORY_AND_DISK) + } else if (cacheMode == CacheType.MEMORY_ONLY) { + outputRDD.persist(StorageLevel.MEMORY_ONLY) } - SharkEnv.memoryMetadataManager.put(tableName, rdd) - rdd.setName(tableName) - - // Run a job on the original RDD to force it to go into cache. - origRdd.context.runJob(origRdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) + outputRDD.context.runJob( + outputRDD, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) } - // Report remaining memory. - /* Commented out for now waiting for the reporting code to make into Spark. - val remainingMems: Map[String, (Long, Long)] = SharkEnv.sc.getSlavesMemoryStatus - remainingMems.foreach { case(slave, mem) => - println("%s: %s / %s".format( - slave, - Utils.memoryBytesToString(mem._2), - Utils.memoryBytesToString(mem._1))) - } - println("Summary: %s / %s".format( - Utils.memoryBytesToString(remainingMems.map(_._2._2).sum), - Utils.memoryBytesToString(remainingMems.map(_._2._1).sum))) - */ - - val columnStats = - if (useUnionRDD) { - // Combine stats for the two tables being combined. - val numPartitions = statsAcc.value.toMap.size - val currentStats = statsAcc.value - val otherIndexToStats = SharkEnv.memoryMetadataManager.getStats(tableName).get - for ((otherIndex, tableStats) <- otherIndexToStats) { - currentStats.append((otherIndex + numPartitions, tableStats)) - } - currentStats.toMap - } else { + // Put the table in Spark block manager or Tachyon. + op.logInfo("Putting %sRDD for %s.%s in %s store".format( + if (isInsertInto) "Union" else "", + databaseName, + tableName, + if (cacheMode == CacheType.NONE) "disk" else cacheMode.toString)) + + val tableStats = + if (cacheMode == CacheType.TACHYON) { + tachyonWriter.updateMetadata(ByteBuffer.wrap(JavaSerializer.serialize(statsAcc.value.toMap))) statsAcc.value.toMap + } else { + val isHivePartitioned = SharkEnv.memoryMetadataManager.isHivePartitioned( + databaseName, tableName) + if (isHivePartitioned) { + val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable( + databaseName, tableName).get + val hivePartitionKey = hivePartitionKeyOpt.get + outputRDD.setName("%s.%s(%s)".format(databaseName, tableName, hivePartitionKey)) + if (isInsertInto) { + // An RDD for the Hive partition already exists, so update its metadata entry in + // 'partitionedTable'. 
+ assert(outputRDD.isInstanceOf[UnionRDD[_]]) + partitionedTable.updatePartition(hivePartitionKey, outputRDD, statsAcc.value) + } else { + // This is a new Hive-partition. Add a new metadata entry in 'partitionedTable'. + partitionedTable.putPartition(hivePartitionKey, outputRDD, statsAcc.value.toMap) + } + // Stats should be updated at this point. + partitionedTable.getStats(hivePartitionKey).get + } else { + outputRDD.setName(tableName) + // Create a new MemoryTable entry if one doesn't exist (i.e., this operator is for a CTAS). + val memoryTable = SharkEnv.memoryMetadataManager.getMemoryTable(databaseName, tableName) + .getOrElse(SharkEnv.memoryMetadataManager.createMemoryTable( + databaseName, tableName, cacheMode)) + if (isInsertInto) { + // Ok, a Tachyon table should manage stats for each rdd, and never union the maps. + memoryTable.update(outputRDD, statsAcc.value) + } else { + memoryTable.put(outputRDD, statsAcc.value.toMap) + } + memoryTable.getStats.get + } } - // Get the column statistics back to the cache manager. - SharkEnv.memoryMetadataManager.putStats(tableName, columnStats) - - if (tachyonWriter != null) { - tachyonWriter.updateMetadata(ByteBuffer.wrap(JavaSerializer.serialize(columnStats))) - } - if (SharkConfVars.getBoolVar(localHconf, SharkConfVars.MAP_PRUNING_PRINT_DEBUG)) { - columnStats.foreach { case(index, tableStats) => - println("Partition " + index + " " + tableStats.toString) + tableStats.foreach { case(index, tablePartitionStats) => + println("Partition " + index + " " + tablePartitionStats.toString) } } - // Return the cached RDD. - rdd + return outputRDD } override def processPartition(split: Int, iter: Iterator[_]): Iterator[_] = diff --git a/src/main/scala/shark/execution/Operator.scala b/src/main/scala/shark/execution/Operator.scala index bee7a566..a350a8a8 100755 --- a/src/main/scala/shark/execution/Operator.scala +++ b/src/main/scala/shark/execution/Operator.scala @@ -17,13 +17,17 @@ package shark.execution -import java.util.{List => JavaList} +import scala.language.existentials +import java.util.{List => JavaList} import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator import org.apache.spark.rdd.RDD @@ -31,7 +35,7 @@ import shark.LogHelper import shark.execution.serialization.OperatorSerializationWrapper -abstract class Operator[T <: HiveOperator] extends LogHelper with Serializable { +abstract class Operator[+T <: HiveDesc] extends LogHelper with Serializable { /** * Initialize the operator on master node. This can have dependency on other @@ -60,7 +64,7 @@ abstract class Operator[T <: HiveOperator] extends LogHelper with Serializable { */ def initializeMasterOnAll() { _parentOperators.foreach(_.initializeMasterOnAll()) - objectInspectors ++= hiveOp.getInputObjInspectors() + objectInspectors = inputObjectInspectors() initializeOnMaster() } @@ -79,18 +83,18 @@ abstract class Operator[T <: HiveOperator] extends LogHelper with Serializable { * Return the parent operators as a Java List. This is for interoperability * with Java. We use this in explain's Java code. 
*/ - def parentOperatorsAsJavaList: JavaList[Operator[_]] = _parentOperators + def parentOperatorsAsJavaList: JavaList[Operator[_<:HiveDesc]] = _parentOperators - def addParent(parent: Operator[_]) { + def addParent(parent: Operator[_<:HiveDesc]) { _parentOperators += parent parent.childOperators += this } - def addChild(child: Operator[_]) { + def addChild(child: Operator[_<:HiveDesc]) { child.addParent(this) } - def returnTerminalOperators(): Seq[Operator[_]] = { + def returnTerminalOperators(): Seq[Operator[_<:HiveDesc]] = { if (_childOperators == null || _childOperators.size == 0) { Seq(this) } else { @@ -106,14 +110,134 @@ abstract class Operator[T <: HiveOperator] extends LogHelper with Serializable { } } - @transient var hiveOp: T = _ - @transient private val _childOperators = new ArrayBuffer[Operator[_]]() - @transient private val _parentOperators = new ArrayBuffer[Operator[_]]() - @transient var objectInspectors = new ArrayBuffer[ObjectInspector] + def desc = _desc + + def setDesc[B >: T](d: B) {_desc = d.asInstanceOf[T]} + + @transient private[this] var _desc: T = _ + @transient private val _childOperators = new ArrayBuffer[Operator[_<:HiveDesc]]() + @transient private val _parentOperators = new ArrayBuffer[Operator[_<:HiveDesc]]() + @transient var objectInspectors: Seq[ObjectInspector] =_ protected def executeParents(): Seq[(Int, RDD[_])] = { parentOperators.map(p => (p.getTag, p.execute())) } + + protected def inputObjectInspectors(): Seq[ObjectInspector] = { + if (null != _parentOperators) { + _parentOperators.sortBy(_.getTag).map(_.outputObjectInspector) + } else { + Seq.empty[ObjectInspector] + } + } + + // derived classes can set this to different object if needed, default is the first input OI + def outputObjectInspector(): ObjectInspector = objectInspectors(0) + + /** + * Copy from the org.apache.hadoop.hive.ql.exec.ReduceSinkOperator + * Initializes array of ExprNodeEvaluator. Adds Union field for distinct + * column indices for group by. + * Puts the return values into a StructObjectInspector with output column + * names. + * + * If distinctColIndices is empty, the object inspector is same as + * {@link Operator#initEvaluatorsAndReturnStruct(ExprNodeEvaluator[], List, ObjectInspector)} + */ + protected def initEvaluatorsAndReturnStruct( + evals: Array[ExprNodeEvaluator] , distinctColIndices: JavaList[JavaList[Integer]] , + outputColNames: JavaList[String], length: Int, rowInspector: ObjectInspector): + StructObjectInspector = { + + val fieldObjectInspectors = initEvaluators(evals, 0, length, rowInspector); + initEvaluatorsAndReturnStruct(evals, fieldObjectInspectors, distinctColIndices, + outputColNames, length, rowInspector) + } + + /** + * Copy from the org.apache.hadoop.hive.ql.exec.ReduceSinkOperator + * Initializes array of ExprNodeEvaluator. Adds Union field for distinct + * column indices for group by. + * Puts the return values into a StructObjectInspector with output column + * names. 
+ * + * If distinctColIndices is empty, the object inspector is same as + * {@link Operator#initEvaluatorsAndReturnStruct(ExprNodeEvaluator[], List, ObjectInspector)} + */ + protected def initEvaluatorsAndReturnStruct( + evals: Array[ExprNodeEvaluator], fieldObjectInspectors: Array[ObjectInspector], + distinctColIndices: JavaList[JavaList[Integer]], outputColNames: JavaList[String], + length: Int, rowInspector: ObjectInspector): StructObjectInspector = { + + val inspectorLen = if (evals.length > length) length + 1 else evals.length + + val sois = new ArrayBuffer[ObjectInspector](inspectorLen) + + // keys + // var fieldObjectInspectors = initEvaluators(evals, 0, length, rowInspector); + sois ++= fieldObjectInspectors + + if (evals.length > length) { + // union keys + val uois = new ArrayBuffer[ObjectInspector]() + for (/*List*/ distinctCols <- distinctColIndices) { + val names = new ArrayBuffer[String]() + val eois = new ArrayBuffer[ObjectInspector]() + var numExprs = 0 + for (i <- distinctCols) { + names.add(HiveConf.getColumnInternalName(numExprs)) + eois.add(evals(i).initialize(rowInspector)) + numExprs += 1 + } + uois.add(ObjectInspectorFactory.getStandardStructObjectInspector(names, eois)) + } + + sois.add(ObjectInspectorFactory.getStandardUnionObjectInspector(uois)) + } + + ObjectInspectorFactory.getStandardStructObjectInspector(outputColNames, sois) + } + + /** + * Initialize an array of ExprNodeEvaluator and return the result + * ObjectInspectors. + */ + protected def initEvaluators(evals: Array[ExprNodeEvaluator], + rowInspector: ObjectInspector): Array[ObjectInspector] = { + val result = new Array[ObjectInspector](evals.length) + for (i <- 0 to evals.length -1) { + result(i) = evals(i).initialize(rowInspector) + } + + result + } + + /** + * Initialize an array of ExprNodeEvaluator from start, for specified length + * and return the result ObjectInspectors. + */ + protected def initEvaluators(evals: Array[ExprNodeEvaluator], + start: Int, length: Int,rowInspector: ObjectInspector): Array[ObjectInspector] = { + val result = new Array[ObjectInspector](length) + + for (i <- 0 to length - 1) { + result(i) = evals(start + i).initialize(rowInspector) + } + + result + } + + /** + * Initialize an array of ExprNodeEvaluator and put the return values into a + * StructObjectInspector with integer field names. + */ + protected def initEvaluatorsAndReturnStruct( + evals: Array[ExprNodeEvaluator], outputColName: JavaList[String], + rowInspector: ObjectInspector): StructObjectInspector = { + val fieldObjectInspectors = initEvaluators(evals, rowInspector) + return ObjectInspectorFactory.getStandardStructObjectInspector( + outputColName, fieldObjectInspectors.toList) + } } @@ -132,7 +256,7 @@ abstract class Operator[T <: HiveOperator] extends LogHelper with Serializable { * processPartition before sending it downstream. * */ -abstract class NaryOperator[T <: HiveOperator] extends Operator[T] { +abstract class NaryOperator[T <: HiveDesc] extends Operator[T] { /** Process a partition. Called on slaves. */ def processPartition(split: Int, iter: Iterator[_]): Iterator[_] @@ -168,7 +292,7 @@ abstract class NaryOperator[T <: HiveOperator] extends Operator[T] { * processPartition before sending it downstream. * */ -abstract class UnaryOperator[T <: HiveOperator] extends Operator[T] { +abstract class UnaryOperator[T <: HiveDesc] extends Operator[T] { /** Process a partition. Called on slaves. 
*/ def processPartition(split: Int, iter: Iterator[_]): Iterator[_] @@ -192,7 +316,7 @@ abstract class UnaryOperator[T <: HiveOperator] extends Operator[T] { } -abstract class TopOperator[T <: HiveOperator] extends UnaryOperator[T] +abstract class TopOperator[T <: HiveDesc] extends UnaryOperator[T] object Operator extends LogHelper { @@ -205,7 +329,7 @@ object Operator extends LogHelper { * to do logging, but calling logging automatically adds a reference to the * operator (which is not serializable by Java) in the Spark closure. */ - def executeProcessPartition(operator: Operator[_ <: HiveOperator], rdd: RDD[_]): RDD[_] = { + def executeProcessPartition(operator: Operator[_ <: HiveDesc], rdd: RDD[_]): RDD[_] = { val op = OperatorSerializationWrapper(operator) rdd.mapPartitionsWithIndex { case(split, partition) => op.logDebug("Started executing mapPartitions for operator: " + op) diff --git a/src/main/scala/shark/execution/OperatorFactory.scala b/src/main/scala/shark/execution/OperatorFactory.scala index fbae568a..ecc0479b 100755 --- a/src/main/scala/shark/execution/OperatorFactory.scala +++ b/src/main/scala/shark/execution/OperatorFactory.scala @@ -20,11 +20,11 @@ package shark.execution import scala.collection.JavaConversions._ import org.apache.hadoop.hive.ql.exec.{GroupByPostShuffleOperator, GroupByPreShuffleOperator} +import org.apache.hadoop.hive.ql.exec.{Operator => HOperator} import org.apache.hadoop.hive.ql.metadata.HiveException -import org.apache.spark.storage.StorageLevel - import shark.LogHelper +import shark.memstore2.CacheType._ /** @@ -37,98 +37,117 @@ object OperatorFactory extends LogHelper { * uses Shark operators. This function automatically finds the Hive terminal * operator, and replicate the plan recursively up. */ - def createSharkPlan(hiveOp: HiveOperator): TerminalOperator = { + def createSharkPlan[T<:HiveDesc](hiveOp: HOperator[T]): TerminalOperator = { val hiveTerminalOp = _findHiveTerminalOperator(hiveOp) _createOperatorTree(hiveTerminalOp).asInstanceOf[TerminalOperator] } def createSharkMemoryStoreOutputPlan( - hiveTerminalOp: HiveOperator, + hiveTerminalOp: HOperator[_<:HiveDesc], tableName: String, - storageLevel: StorageLevel, + databaseName: String, numColumns: Int, - useTachyon: Boolean, - useUnionRDD: Boolean): TerminalOperator = { + hivePartitionKeyOpt: Option[String], + cacheMode: CacheType, + isInsertInto: Boolean): TerminalOperator = { + // TODO the terminal operator is the FileSinkOperator in Hive? 
+ val hiveOp = hiveTerminalOp.asInstanceOf[org.apache.hadoop.hive.ql.exec.FileSinkOperator] val sinkOp = _newOperatorInstance( - classOf[MemoryStoreSinkOperator], hiveTerminalOp).asInstanceOf[MemoryStoreSinkOperator] + classOf[MemoryStoreSinkOperator], hiveOp).asInstanceOf[MemoryStoreSinkOperator] + sinkOp.localHiveOp = hiveOp sinkOp.tableName = tableName - sinkOp.storageLevel = storageLevel + sinkOp.databaseName = databaseName sinkOp.numColumns = numColumns - sinkOp.useTachyon = useTachyon - sinkOp.useUnionRDD = useUnionRDD + sinkOp.cacheMode = cacheMode + sinkOp.hivePartitionKeyOpt = hivePartitionKeyOpt + sinkOp.isInsertInto = isInsertInto _createAndSetParents(sinkOp, hiveTerminalOp.getParentOperators).asInstanceOf[TerminalOperator] } - def createSharkFileOutputPlan(hiveTerminalOp: HiveOperator): TerminalOperator = { - val sinkOp = _newOperatorInstance(classOf[FileSinkOperator], hiveTerminalOp) - _createAndSetParents( - sinkOp, hiveTerminalOp.getParentOperators).asInstanceOf[TerminalOperator] + def createSharkFileOutputPlan(hiveTerminalOp: HOperator[_<:HiveDesc]): TerminalOperator = { + // TODO the terminal operator is the FileSinkOperator in Hive? + val hiveOp = hiveTerminalOp.asInstanceOf[org.apache.hadoop.hive.ql.exec.FileSinkOperator] + val sinkOp = _newOperatorInstance(classOf[FileSinkOperator], + hiveOp).asInstanceOf[TerminalOperator] + sinkOp.localHiveOp = hiveOp + _createAndSetParents(sinkOp, hiveTerminalOp.getParentOperators).asInstanceOf[TerminalOperator] } - def createSharkRddOutputPlan(hiveTerminalOp: HiveOperator): TerminalOperator = { - val sinkOp = _newOperatorInstance(classOf[TableRddSinkOperator], hiveTerminalOp) + def createSharkRddOutputPlan(hiveTerminalOp: HOperator[_<:HiveDesc]): TerminalOperator = { + // TODO the terminal operator is the FileSinkOperator in Hive? + val hiveOp = hiveTerminalOp.asInstanceOf[org.apache.hadoop.hive.ql.exec.FileSinkOperator] + val sinkOp = _newOperatorInstance(classOf[TableRddSinkOperator], + hiveOp).asInstanceOf[TableRddSinkOperator] + sinkOp.localHiveOp = hiveOp _createAndSetParents(sinkOp, hiveTerminalOp.getParentOperators).asInstanceOf[TerminalOperator] } /** Create a Shark operator given the Hive operator. */ - private def createSingleOperator(hiveOp: HiveOperator): Operator[_] = { + private def createSingleOperator[T<:HiveDesc](hiveOp: HOperator[T]): Operator[T] = { // This is kind of annoying, but it works with strong typing ... 
val sharkOp = hiveOp match { case hop: org.apache.hadoop.hive.ql.exec.TableScanOperator => - _newOperatorInstance(classOf[TableScanOperator], hiveOp) + val op = _newOperatorInstance(classOf[TableScanOperator], hop) + op.asInstanceOf[TableScanOperator].hiveOp = hop + op case hop: org.apache.hadoop.hive.ql.exec.SelectOperator => - _newOperatorInstance(classOf[SelectOperator], hiveOp) + _newOperatorInstance(classOf[SelectOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.FileSinkOperator => - _newOperatorInstance(classOf[TerminalOperator], hiveOp) + val op = _newOperatorInstance(classOf[TerminalOperator], hop) + op.asInstanceOf[TerminalOperator].localHiveOp = hop + op case hop: org.apache.hadoop.hive.ql.exec.LimitOperator => - _newOperatorInstance(classOf[LimitOperator], hiveOp) + _newOperatorInstance(classOf[LimitOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.FilterOperator => - _newOperatorInstance(classOf[FilterOperator], hiveOp) + _newOperatorInstance(classOf[FilterOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.SamplingOperator => - _newOperatorInstance(classOf[SamplingOperator], hiveOp) + _newOperatorInstance(classOf[SamplingOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.ReduceSinkOperator => - _newOperatorInstance(classOf[ReduceSinkOperator], hiveOp) + _newOperatorInstance(classOf[ReduceSinkOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.ExtractOperator => - _newOperatorInstance(classOf[ExtractOperator], hiveOp) + _newOperatorInstance(classOf[ExtractOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.UnionOperator => - _newOperatorInstance(classOf[UnionOperator], hiveOp) + _newOperatorInstance(classOf[UnionOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.JoinOperator => - _newOperatorInstance(classOf[JoinOperator], hiveOp) + _newOperatorInstance(classOf[JoinOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.MapJoinOperator => - _newOperatorInstance(classOf[MapJoinOperator], hiveOp) + _newOperatorInstance(classOf[MapJoinOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.ScriptOperator => - _newOperatorInstance(classOf[ScriptOperator], hiveOp) + val op = _newOperatorInstance(classOf[ScriptOperator], hop) + op.asInstanceOf[ScriptOperator].operatorId = hop.getOperatorId() + op case hop: org.apache.hadoop.hive.ql.exec.LateralViewForwardOperator => - _newOperatorInstance(classOf[LateralViewForwardOperator], hiveOp) + _newOperatorInstance(classOf[LateralViewForwardOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.LateralViewJoinOperator => - _newOperatorInstance(classOf[LateralViewJoinOperator], hiveOp) + _newOperatorInstance(classOf[LateralViewJoinOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.UDTFOperator => - _newOperatorInstance(classOf[UDTFOperator], hiveOp) + _newOperatorInstance(classOf[UDTFOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.ForwardOperator => - _newOperatorInstance(classOf[ForwardOperator], hiveOp) + _newOperatorInstance(classOf[ForwardOperator], hop) case hop: org.apache.hadoop.hive.ql.exec.GroupByOperator => { // For GroupBy, we separate post shuffle from pre shuffle. 
if (GroupByOperator.isPostShuffle(hop)) { - _newOperatorInstance(classOf[GroupByPostShuffleOperator], hiveOp) + _newOperatorInstance(classOf[GroupByPostShuffleOperator], hop) } else { - _newOperatorInstance(classOf[GroupByPreShuffleOperator], hiveOp) + _newOperatorInstance(classOf[GroupByPreShuffleOperator], hop) } } case _ => throw new HiveException("Unsupported Hive operator: " + hiveOp.getClass.getName) } logDebug("Replacing %s with %s".format(hiveOp.getClass.getName, sharkOp.getClass.getName)) - sharkOp + sharkOp.asInstanceOf[Operator[T]] } - private def _newOperatorInstance[T <: HiveOperator]( - cls: Class[_ <: Operator[T]], hiveOp: HiveOperator): Operator[_] = { + private def _newOperatorInstance[T <: HiveDesc]( + cls: Class[_ <: Operator[T]], hiveOp: HOperator[T]): Operator[T] = { val op = cls.newInstance() - op.hiveOp = hiveOp.asInstanceOf[T] + op.setDesc(hiveOp.getConf()) op } - private def _createAndSetParents(op: Operator[_], parents: Seq[HiveOperator]) = { + private def _createAndSetParents[T <: HiveDesc](op: Operator[T], + parents: Seq[HOperator[_<:HiveDesc]]) = { if (parents != null) { parents foreach { parent => _createOperatorTree(parent).addChild(op) @@ -141,7 +160,7 @@ object OperatorFactory extends LogHelper { * Given a terminal operator in Hive, create the plan that uses Shark physical * operators. */ - private def _createOperatorTree(hiveOp: HiveOperator): Operator[_] = { + private def _createOperatorTree[T<:HiveDesc](hiveOp: HOperator[T]): Operator[T] = { val current = createSingleOperator(hiveOp) val parents = hiveOp.getParentOperators if (parents != null) { @@ -152,7 +171,7 @@ object OperatorFactory extends LogHelper { } } - private def _findHiveTerminalOperator(hiveOp: HiveOperator): HiveOperator = { + private def _findHiveTerminalOperator(hiveOp: HOperator[_<:HiveDesc]): HOperator[_<:HiveDesc] = { if (hiveOp.getChildOperators() == null || hiveOp.getChildOperators().size() == 0) { hiveOp } else { diff --git a/src/main/scala/shark/execution/RDDUtils.scala b/src/main/scala/shark/execution/RDDUtils.scala index bd68890b..1306f9bd 100755 --- a/src/main/scala/shark/execution/RDDUtils.scala +++ b/src/main/scala/shark/execution/RDDUtils.scala @@ -18,12 +18,12 @@ package shark.execution import scala.collection.JavaConversions +import scala.reflect.ClassTag import com.google.common.collect.{Ordering => GOrdering} import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner} import org.apache.spark.rdd.{RDD, ShuffledRDD, UnionRDD} -import org.apache.spark.storage.StorageLevel import shark.SharkEnv @@ -34,26 +34,46 @@ import shark.SharkEnv */ object RDDUtils { - def getStorageLevelOfCachedTable(rdd: RDD[_]): StorageLevel = { + /** + * Returns a UnionRDD using both RDD arguments. Any UnionRDD argument is "flattened", in that + * its parent sequence of RDDs is directly passed to the UnionRDD returned. 
+ */ + def unionAndFlatten[T: ClassTag]( + rdd: RDD[T], + otherRdd: RDD[T]): UnionRDD[T] = { + val otherRdds: Seq[RDD[T]] = otherRdd match { + case otherUnionRdd: UnionRDD[_] => otherUnionRdd.rdds + case _ => Seq(otherRdd) + } + val rdds: Seq[RDD[T]] = rdd match { + case unionRdd: UnionRDD[_] => unionRdd.rdds + case _ => Seq(rdd) + } + new UnionRDD(rdd.context, rdds ++ otherRdds) + } + + def unpersistRDD(rdd: RDD[_]): RDD[_] = { rdd match { - case u: UnionRDD[_] => u.rdds.foldLeft(rdd.getStorageLevel) { - (s, r) => { - if (s == StorageLevel.NONE) { - getStorageLevelOfCachedTable(r) - } else { - s - } + case u: UnionRDD[_] => { + // Usually, a UnionRDD will not be persisted to avoid data duplication. + u.unpersist() + // unpersist() all parent RDDs that compose the UnionRDD. Don't propagate past the parents, + // since a grandparent of the UnionRDD might have multiple child RDDs (i.e., the sibling of + // the UnionRDD's parent is persisted in memory). + u.rdds.map { + r => r.unpersist() } } - case _ => rdd.getStorageLevel + case r => r.unpersist() } + return rdd } /** * Repartition an RDD using the given partitioner. This is similar to Spark's partitionBy, * except we use the Shark shuffle serializer. */ - def repartition[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)], part: Partitioner) + def repartition[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], part: Partitioner) : RDD[(K, V)] = { new ShuffledRDD[K, V, (K, V)](rdd, part).setSerializer(SharkEnv.shuffleSerializerName) @@ -63,7 +83,7 @@ object RDDUtils { * Sort the RDD by key. This is similar to Spark's sortByKey, except that we use * the Shark shuffle serializer. */ - def sortByKey[K <: Comparable[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) + def sortByKey[K <: Comparable[K]: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) : RDD[(K, V)] = { val part = new RangePartitioner(rdd.partitions.length, rdd) @@ -72,13 +92,13 @@ object RDDUtils { shuffled.mapPartitions(iter => { val buf = iter.toArray buf.sortWith((x, y) => x._1.compareTo(y._1) < 0).iterator - }, true) + }, preservesPartitioning = true) } /** * Return an RDD containing the top K (K smallest key) from the given RDD. */ - def topK[K <: Comparable[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)], k: Int) + def topK[K <: Comparable[K]: ClassTag, V: ClassTag](rdd: RDD[(K, V)], k: Int) : RDD[(K, V)] = { // First take top K on each partition. @@ -90,7 +110,7 @@ object RDDUtils { /** * Take top K on each partition and return a new RDD. */ - def partitionTopK[K <: Comparable[K]: ClassManifest, V: ClassManifest]( + def partitionTopK[K <: Comparable[K]: ClassTag, V: ClassTag]( rdd: RDD[(K, V)], k: Int): RDD[(K, V)] = { rdd.mapPartitions(iter => topK(iter, k)) } @@ -98,7 +118,7 @@ object RDDUtils { /** * Return top K elements out of an iterator. 
*/ - private def topK[K <: Comparable[K]: ClassManifest, V: ClassManifest]( + private def topK[K <: Comparable[K]: ClassTag, V: ClassTag]( it: Iterator[(K, V)], k: Int): Iterator[(K, V)] = { val ordering = new GOrdering[(K,V)] { override def compare(l: (K, V), r: (K, V)) = { diff --git a/src/main/scala/shark/execution/ReduceKey.scala b/src/main/scala/shark/execution/ReduceKey.scala index d91df3bc..e2436b19 100755 --- a/src/main/scala/shark/execution/ReduceKey.scala +++ b/src/main/scala/shark/execution/ReduceKey.scala @@ -79,7 +79,8 @@ class ReduceKeyMapSide(var bytesWritable: BytesWritable) extends ReduceKey if (length != other.length) { false } else { - WritableComparator.compareBytes(byteArray, 0, length, other.byteArray, 0, other.length) == 0 + WritableComparator.compareBytes( + byteArray, 0, length, other.byteArray, 0, other.length) == 0 } } case _ => false @@ -116,10 +117,12 @@ class ReduceKeyReduceSide(private val _byteArray: Array[Byte]) extends ReduceKey override def length: Int = byteArray.length override def equals(other: Any): Boolean = { - // We expect this is only used in a hash table comparing to the same types. - // So we force a type cast. - val that = other.asInstanceOf[ReduceKeyReduceSide] - (this.byteArray.length == that.byteArray.length && this.compareTo(that) == 0) + other match { + case that: ReduceKeyReduceSide => { + (this.byteArray.length == that.byteArray.length) && (this.compareTo(that) == 0) + } + case _ => false + } } override def compareTo(that: ReduceKey): Int = { diff --git a/src/main/scala/shark/execution/ReduceSinkOperator.scala b/src/main/scala/shark/execution/ReduceSinkOperator.scala index 6c01ab47..619544a3 100755 --- a/src/main/scala/shark/execution/ReduceSinkOperator.scala +++ b/src/main/scala/shark/execution/ReduceSinkOperator.scala @@ -23,13 +23,12 @@ import scala.collection.Iterator import scala.collection.JavaConversions._ import scala.reflect.BeanProperty -import org.apache.hadoop.hive.ql.exec.{ReduceSinkOperator => HiveReduceSinkOperator} import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, ExprNodeEvaluatorFactory} -import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc import org.apache.hadoop.hive.serde2.SerDe -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, - ObjectInspectorUtils} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion import org.apache.hadoop.io.BytesWritable @@ -38,7 +37,7 @@ import org.apache.hadoop.io.BytesWritable * Converts a collection of rows into key, value pairs. This is the * upstream operator for joins and groupbys. 
*/ -class ReduceSinkOperator extends UnaryOperator[HiveReduceSinkOperator] { +class ReduceSinkOperator extends UnaryOperator[ReduceSinkDesc] { @BeanProperty var conf: ReduceSinkDesc = _ @@ -58,16 +57,22 @@ class ReduceSinkOperator extends UnaryOperator[HiveReduceSinkOperator] { @transient var keySer: SerDe = _ @transient var valueSer: SerDe = _ @transient var keyObjInspector: ObjectInspector = _ + @transient var keyFieldObjInspectors: Array[ObjectInspector] = _ @transient var valObjInspector: ObjectInspector = _ + @transient var valFieldObjInspectors: Array[ObjectInspector] = _ @transient var partitionObjInspectors: Array[ObjectInspector] = _ override def getTag() = conf.getTag() override def initializeOnMaster() { - conf = hiveOp.getConf() + super.initializeOnMaster() + + conf = desc } override def initializeOnSlave() { + super.initializeOnSlave() + initializeOisAndSers(conf, objectInspector) } @@ -79,45 +84,26 @@ class ReduceSinkOperator extends UnaryOperator[HiveReduceSinkOperator] { } } - def initializeDownStreamHiveOperator() { - - conf = hiveOp.getConf() - - // Note that we get input object inspector from hiveOp rather than Shark's - // objectInspector because initializeMasterOnAll() hasn't been invoked yet. - initializeOisAndSers(conf, hiveOp.getInputObjInspectors().head) - - // Determine output object inspector (a struct of KEY, VALUE). + override def outputObjectInspector() = { + initializeOisAndSers(conf, objectInspector) + val ois = new ArrayList[ObjectInspector] ois.add(keySer.getObjectInspector) ois.add(valueSer.getObjectInspector) - - val outputObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(List("KEY","VALUE"), ois) - - val joinTag = conf.getTag() - - // Propagate the output object inspector and serde infos to downstream operator. - childOperators.foreach { child => - child match { - case child: HiveTopOperator => { - child.setInputObjectInspector(joinTag, outputObjInspector) - child.setKeyValueTableDescs(joinTag, - (conf.getKeySerializeInfo, conf.getValueSerializeInfo)) - } - case _ => { - throw new HiveException("%s's downstream operator should be %s. %s found.".format( - this.getClass.getName, classOf[HiveTopOperator].getName, child.getClass.getName)) - } - } - } + ObjectInspectorFactory.getStandardStructObjectInspector(List("KEY", "VALUE"), ois) } + // will be used of the children operators (in JoinOperator/Extractor/GroupByPostShuffleOperator + def getKeyValueTableDescs() = (conf.getKeySerializeInfo, conf.getValueSerializeInfo) + /** * Initialize the object inspectors, evaluators, and serializers. Used on * both the master and the slave. */ private def initializeOisAndSers(conf: ReduceSinkDesc, rowInspector: ObjectInspector) { keyEval = conf.getKeyCols.map(ExprNodeEvaluatorFactory.get(_)).toArray + keyFieldObjInspectors = initEvaluators(keyEval, 0, keyEval.length, rowInspector) + val numDistributionKeys = conf.getNumDistributionKeys() val distinctColIndices = conf.getDistinctColumnIndices() valueEval = conf.getValueCols.map(ExprNodeEvaluatorFactory.get(_)).toArray @@ -133,17 +119,17 @@ class ReduceSinkOperator extends UnaryOperator[HiveReduceSinkOperator] { valueSer.initialize(null, valueTableDesc.getProperties()) // Initialize object inspector for key columns. 
- keyObjInspector = ReduceSinkOperatorHelper.initEvaluatorsAndReturnStruct( - keyEval, - distinctColIndices, - conf.getOutputKeyColumnNames, - numDistributionKeys, - rowInspector) + keyObjInspector = initEvaluatorsAndReturnStruct( + keyEval, + distinctColIndices, + conf.getOutputKeyColumnNames, + numDistributionKeys, + rowInspector) // Initialize object inspector for value columns. - val valFieldInspectors = valueEval.map(eval => eval.initialize(rowInspector)).toList + valFieldObjInspectors = valueEval.map(eval => eval.initialize(rowInspector)) valObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector( - conf.getOutputValueColumnNames(), valFieldInspectors) + conf.getOutputValueColumnNames(), valFieldObjInspectors.toList) // Initialize evaluator and object inspector for partition columns. partitionEval = conf.getPartitionCols.map(ExprNodeEvaluatorFactory.get(_)).toArray diff --git a/src/main/scala/shark/execution/ReduceSinkTableDesc.scala b/src/main/scala/shark/execution/ReduceSinkTableDesc.scala new file mode 100644 index 00000000..5acefae7 --- /dev/null +++ b/src/main/scala/shark/execution/ReduceSinkTableDesc.scala @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package shark.execution + +import org.apache.hadoop.hive.ql.plan.TableDesc +import shark.LogHelper + + +trait ReduceSinkTableDesc extends LogHelper { + self: Operator[_ <: HiveDesc] => + + // Seq(tag, (Key TableDesc, Value TableDesc)) + def keyValueDescs(): Seq[(Int, (TableDesc, TableDesc))] = { + // get the parent ReduceSinkOperator and sort it by tag + val reduceSinkOps = + for (op <- self.parentOperators.toSeq if op.isInstanceOf[ReduceSinkOperator]) + yield op.asInstanceOf[ReduceSinkOperator] + + reduceSinkOps.map(f => (f.getTag, f.getKeyValueTableDescs)) + } +} diff --git a/src/main/scala/shark/execution/ScriptOperator.scala b/src/main/scala/shark/execution/ScriptOperator.scala index f8b56c82..14203af4 100755 --- a/src/main/scala/shark/execution/ScriptOperator.scala +++ b/src/main/scala/shark/execution/ScriptOperator.scala @@ -17,8 +17,9 @@ package shark.execution -import java.io.{File, InputStream} -import java.util.{Arrays, Properties} +import java.io.{File, InputStream, IOException} +import java.lang.Thread.UncaughtExceptionHandler +import java.util.Properties import scala.collection.JavaConversions._ import scala.io.Source @@ -26,17 +27,19 @@ import scala.reflect.BeanProperty import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.ql.exec.{ScriptOperator => HiveScriptOperator} -import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter, ScriptOperatorHelper} +import org.apache.hadoop.hive.ql.exec.{ScriptOperatorHelper => HiveScriptOperatorHelper} import org.apache.hadoop.hive.ql.plan.ScriptDesc import org.apache.hadoop.hive.serde2.{Serializer, Deserializer} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector import org.apache.hadoop.io.{BytesWritable, Writable} -import org.apache.spark.{OneToOneDependency, SparkEnv, SparkFiles} +import org.apache.spark.{SparkEnv, SparkFiles} import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext import shark.execution.serialization.OperatorSerializationWrapper +import shark.LogHelper /** @@ -44,12 +47,13 @@ import shark.execution.serialization.OperatorSerializationWrapper * * Example: select transform(key) using 'cat' as cola from src; */ -class ScriptOperator extends UnaryOperator[HiveScriptOperator] { +class ScriptOperator extends UnaryOperator[ScriptDesc] { - @BeanProperty var localHiveOp: HiveScriptOperator = _ @BeanProperty var localHconf: HiveConf = _ @BeanProperty var alias: String = _ + @BeanProperty var conf: ScriptDesc = _ + @transient var operatorId: String = _ @transient var scriptInputSerializer: Serializer = _ @transient var scriptOutputDeserializer: Deserializer = _ @@ -62,12 +66,12 @@ class ScriptOperator extends UnaryOperator[HiveScriptOperator] { val op = OperatorSerializationWrapper(this) val (command, envs) = getCommandAndEnvs() - val outRecordReaderClass: Class[_ <: RecordReader] = hiveOp.getConf().getOutRecordReaderClass() - val inRecordWriterClass: Class[_ <: RecordWriter] = hiveOp.getConf().getInRecordWriterClass() - logInfo("Using %s and %s".format(outRecordReaderClass, inRecordWriterClass)) + val outRecordReaderClass: Class[_ <: RecordReader] = conf.getOutRecordReaderClass() + val inRecordWriterClass: Class[_ <: RecordWriter] = conf.getInRecordWriterClass() + logDebug("Using %s and %s".format(outRecordReaderClass, inRecordWriterClass)) // Deserialize the output from script back to what Hive understands. 
- inputRdd.mapPartitions { part => + inputRdd.mapPartitionsWithContext { (context, part) => op.initializeOnSlave() // Serialize the data so it is recognizable by the script. @@ -97,28 +101,37 @@ class ScriptOperator extends UnaryOperator[HiveScriptOperator] { // Get the thread local SparkEnv so we can pass it into the new thread. val sparkEnv = SparkEnv.get + // If true, exceptions thrown by child threads will be ignored. + val allowPartialConsumption = op.allowPartialConsumption + // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + command) { + val errorReaderThread = new Thread("stderr reader for " + command) { override def run() { - for(line <- Source.fromInputStream(proc.getErrorStream).getLines) { + for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { System.err.println(line) } } - }.start() + } + errorReaderThread.setUncaughtExceptionHandler( + new ScriptOperator.ScriptExceptionHandler(allowPartialConsumption, context)) + errorReaderThread.start() // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + command) { + val inputWriterThread = new Thread("stdin writer for " + command) { override def run() { // Set the thread local SparkEnv. SparkEnv.set(sparkEnv) val recordWriter = inRecordWriterClass.newInstance recordWriter.initialize(proc.getOutputStream, op.localHconf) - for(elem <- iter) { + for (elem <- iter) { recordWriter.write(elem) } recordWriter.close() } - }.start() + } + inputWriterThread.setUncaughtExceptionHandler( + new ScriptOperator.ScriptExceptionHandler(allowPartialConsumption, context)) + inputWriterThread.start() // Return an iterator that reads outputs from RecordReader. Use our own // BinaryRecordReader if necessary because Hive's has a bug (see below). @@ -131,31 +144,29 @@ class ScriptOperator extends UnaryOperator[HiveScriptOperator] { recordReader.initialize( proc.getInputStream, op.localHconf, - op.localHiveOp.getConf().getScriptOutputInfo().getProperties()) + op.conf.getScriptOutputInfo().getProperties()) op.deserializeFromScript(new ScriptOperator.RecordReaderIterator(recordReader)) } } override def initializeOnMaster() { - localHiveOp = hiveOp + super.initializeOnMaster() localHconf = super.hconf - // Set parent to null so we won't serialize the entire query plan. 
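The stderr-reader and stdin-writer threads above are wired to an uncaught-exception handler so that failures in the external script are not silently swallowed; the handler itself (ScriptOperator.ScriptExceptionHandler) appears later in this diff. As a minimal standalone sketch of the underlying JVM pattern only, with a plain AtomicReference standing in for Spark's on-task-completion callback and an invented object name:

import java.lang.Thread.UncaughtExceptionHandler
import java.util.concurrent.atomic.AtomicReference

object ChildThreadErrorSketch {
  def main(args: Array[String]) {
    // Parent-visible slot for an error raised on a child thread.
    val childError = new AtomicReference[Throwable](null)

    val writer = new Thread("stdin writer (sketch)") {
      override def run() {
        throw new java.io.IOException("Broken pipe")  // simulate the script exiting early
      }
    }
    // The handler runs on the child thread; the parent decides later whether to re-throw,
    // which is the same deferral idea as registering an on-task-completion callback.
    writer.setUncaughtExceptionHandler(new UncaughtExceptionHandler {
      override def uncaughtException(t: Thread, e: Throwable) { childError.set(e) }
    })
    writer.start()
    writer.join()

    val allowPartialConsumption = false  // stand-in for the Hive partial-consumption flag
    Option(childError.get).foreach { e =>
      if (allowPartialConsumption) println("Ignoring " + e.getMessage)
      else throw e  // surface the child's failure to the caller
    }
  }
}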
- hiveOp.setParentOperators(null) - hiveOp.setChildOperators(null) - hiveOp.setInputObjInspectors(null) + conf = desc + + initializeOnSlave() } + override def outputObjectInspector() = scriptOutputDeserializer.getObjectInspector() + override def initializeOnSlave() { - scriptOutputDeserializer = localHiveOp.getConf().getScriptOutputInfo() - .getDeserializerClass().newInstance() - scriptOutputDeserializer.initialize(localHconf, localHiveOp.getConf() - .getScriptOutputInfo().getProperties()) - - scriptInputSerializer = localHiveOp.getConf().getScriptInputInfo().getDeserializerClass() - .newInstance().asInstanceOf[Serializer] - scriptInputSerializer.initialize( - localHconf, localHiveOp.getConf().getScriptInputInfo().getProperties()) + scriptOutputDeserializer = conf.getScriptOutputInfo().getDeserializerClass().newInstance() + scriptOutputDeserializer.initialize(localHconf, conf.getScriptOutputInfo().getProperties()) + + scriptInputSerializer = conf.getScriptInputInfo().getDeserializerClass() + .newInstance().asInstanceOf[Serializer] + scriptInputSerializer.initialize(localHconf, conf.getScriptInputInfo().getProperties()) } /** @@ -164,17 +175,17 @@ class ScriptOperator extends UnaryOperator[HiveScriptOperator] { */ def getCommandAndEnvs(): (Seq[String], Map[String, String]) = { - val scriptOpHelper = new ScriptOperatorHelper(new HiveScriptOperator) + val scriptOpHelper = new HiveScriptOperatorHelper(new HiveScriptOperator) alias = scriptOpHelper.getAlias - val cmdArgs = HiveScriptOperator.splitArgs(hiveOp.getConf().getScriptCmd()) + val cmdArgs = HiveScriptOperator.splitArgs(conf.getScriptCmd()) val prog = cmdArgs(0) val currentDir = new File(".").getAbsoluteFile() if (!(new File(prog)).isAbsolute()) { val finder = scriptOpHelper.newPathFinderInstance("PATH") finder.prependPathComponent(currentDir.toString()) - var f = finder.getAbsolutePath(prog) + val f = finder.getAbsolutePath(prog) if (f != null) { cmdArgs(0) = f.getAbsolutePath() } @@ -191,12 +202,12 @@ class ScriptOperator extends UnaryOperator[HiveScriptOperator] { scriptOpHelper.addJobConfToEnvironment(hconf, envs) envs.put(scriptOpHelper.safeEnvVarName(HiveConf.ConfVars.HIVEALIAS.varname), - String.valueOf(alias)) + String.valueOf(alias)) // Create an environment variable that uniquely identifies this script // operator val idEnvVarName = HiveConf.getVar(hconf, HiveConf.ConfVars.HIVESCRIPTIDENVVAR) - val idEnvVarVal = hiveOp.getOperatorId() + val idEnvVarVal = operatorId envs.put(scriptOpHelper.safeEnvVarName(idEnvVarName), idEnvVarVal) (wrappedCmdArgs, Map.empty ++ envs) @@ -218,6 +229,9 @@ class ScriptOperator extends UnaryOperator[HiveScriptOperator] { } } + def allowPartialConsumption: Boolean = + HiveConf.getBoolVar(localHconf, HiveConf.ConfVars.ALLOWPARTIALCONSUMP) + def serializeForScript[T](iter: Iterator[T]): Iterator[Writable] = iter.map { row => scriptInputSerializer.serialize(row, objectInspector) } @@ -227,6 +241,47 @@ class ScriptOperator extends UnaryOperator[HiveScriptOperator] { object ScriptOperator { + /** + * A general exception handler to attach to child threads used to feed input rows and forward + * errors to the parent thread during ScriptOperator#execute(). + * If partial query consumption is not allowed (see HiveConf.Confvars.ALLOWPARTIALCONSUMP), then + * exceptions from child threads are caught by the handler and re-thrown by the parent thread + * through an on-task-completion callback registered with the Spark TaskContext. 
The task will be + * marked "failed" and the exception will be propagated to the master/CLI. + */ + class ScriptExceptionHandler(allowPartialConsumption: Boolean, context: TaskContext) + extends UncaughtExceptionHandler + with LogHelper { + + override def uncaughtException(thread: Thread, throwable: Throwable) { + throwable match { + case ioe: IOException => { + // Check whether the IOException should be re-thrown by the parent thread. + if (allowPartialConsumption) { + logWarning("Error while executing script. Ignoring %s" + .format(ioe.getMessage)) + } else { + val onCompleteCallback = () => { + logWarning("Error during script execution. Set %s=true to ignore thrown IOExceptions." + .format(HiveConf.ConfVars.ALLOWPARTIALCONSUMP.toString)) + throw ioe + } + context.synchronized { + context.addOnCompleteCallback(onCompleteCallback) + } + } + } + case _ => { + // Throw any other Exceptions or Errors. + val onCompleteCallback = () => throw throwable + context.synchronized { + context.addOnCompleteCallback(onCompleteCallback) + } + } + } + } + } + /** * An iterator that wraps around a Hive RecordReader. */ @@ -285,7 +340,7 @@ object ScriptOperator { if (recordLength >= 0) { bytesWritable.setSize(recordLength) } - return recordLength; + return recordLength } override def close() { if (in != null) { in.close() } } diff --git a/src/main/scala/shark/execution/SelectOperator.scala b/src/main/scala/shark/execution/SelectOperator.scala index 3361cfee..bcca156a 100755 --- a/src/main/scala/shark/execution/SelectOperator.scala +++ b/src/main/scala/shark/execution/SelectOperator.scala @@ -21,31 +21,39 @@ import scala.collection.JavaConversions._ import scala.reflect.BeanProperty import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, ExprNodeEvaluatorFactory} -import org.apache.hadoop.hive.ql.exec.{SelectOperator => HiveSelectOperator} import org.apache.hadoop.hive.ql.plan.SelectDesc +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector /** * An operator that does projection, i.e. selecting certain columns and * filtering out others. */ -class SelectOperator extends UnaryOperator[HiveSelectOperator] { +class SelectOperator extends UnaryOperator[SelectDesc] { @BeanProperty var conf: SelectDesc = _ @transient var evals: Array[ExprNodeEvaluator] = _ override def initializeOnMaster() { - conf = hiveOp.getConf() + super.initializeOnMaster() + conf = desc + initializeEvals(false) } - - override def initializeOnSlave() { + + def initializeEvals(initializeEval: Boolean) { if (!conf.isSelStarNoCompute) { evals = conf.getColList().map(ExprNodeEvaluatorFactory.get(_)).toArray - evals.foreach(_.initialize(objectInspector)) + if (initializeEval) { + evals.foreach(_.initialize(objectInspector)) + } } } + override def initializeOnSlave() { + initializeEvals(true) + } + override def processPartition(split: Int, iter: Iterator[_]) = { if (conf.isSelStarNoCompute) { iter @@ -61,4 +69,12 @@ class SelectOperator extends UnaryOperator[HiveSelectOperator] { } } } + + override def outputObjectInspector(): ObjectInspector = { + if (conf.isSelStarNoCompute()) { + super.outputObjectInspector() + } else { + initEvaluatorsAndReturnStruct(evals, conf.getOutputColumnNames(), objectInspector) + } + } } diff --git a/src/main/scala/shark/execution/SharkDDLTask.scala b/src/main/scala/shark/execution/SharkDDLTask.scala new file mode 100644 index 00000000..21707ed3 --- /dev/null +++ b/src/main/scala/shark/execution/SharkDDLTask.scala @@ -0,0 +1,237 @@ +/* + * Copyright (C) 2012 The Regents of The University California. 
+ * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.execution + +import java.util.{List => JavaList, Map => JavaMap} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.ql.{Context, DriverContext} +import org.apache.hadoop.hive.ql.exec.{Task => HiveTask} +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.plan._ +import org.apache.hadoop.hive.ql.plan.api.StageType + +import org.apache.spark.rdd.EmptyRDD + +import shark.{LogHelper, SharkEnv} +import shark.memstore2.{CacheType, MemoryTable, MemoryMetadataManager, PartitionedMemoryTable} +import shark.memstore2.{SharkTblProperties, TablePartitionStats} +import shark.util.HiveUtils + + +private[shark] class SharkDDLWork(val ddlDesc: DDLDesc) extends java.io.Serializable { + + var cacheMode: CacheType.CacheType = _ + +} + +/** + * A task used for Shark-specific metastore operations needed for DDL commands, in addition to the + * metastore updates done by Hive's DDLTask. + * + * Validity checks for DDL commands, such as whether a target table for a CREATE TABLE command + * already exists, is not done by SharkDDLTask. Instead, the SharkDDLTask is meant to be used as + * a dependent task of Hive's DDLTask, which handles all error checking. This way, a SharkDDLTask + * is executed only if the Hive DDLTask is successfully executed - i.e., the DDL statement is a + * valid one. + */ +private[shark] class SharkDDLTask extends HiveTask[SharkDDLWork] + with Serializable with LogHelper { + + override def execute(driverContext: DriverContext): Int = { + val hiveDb = Hive.get(conf) + + // TODO(harvey): Check whether the `hiveDb` is needed. HiveTask should already have a `db` to + // use. + work.ddlDesc match { + case creatTblDesc: CreateTableDesc => createTable(hiveDb, creatTblDesc, work.cacheMode) + case addPartitionDesc: AddPartitionDesc => addPartition(hiveDb, addPartitionDesc, work.cacheMode) + case dropTableDesc: DropTableDesc => dropTableOrPartition(hiveDb, dropTableDesc, work.cacheMode) + case alterTableDesc: AlterTableDesc => alterTable(hiveDb, alterTableDesc, work.cacheMode) + case _ => { + throw new UnsupportedOperationException( + "Shark does not require a Shark DDL task for: " + work.ddlDesc.getClass.getName) + } + } + + // Hive's task runner expects a '0' return value to indicate success, and an exception + // otherwise + return 0 + } + + /** + * Updates Shark metastore for a CREATE TABLE or CTAS command. + * + * @param hiveMetadataDb Namespace of the table to create. + * @param createTblDesc Hive metadata object that contains fields needed to create a Shark Table + * entry. + * @param cacheMode How the created table should be stored and maintained (e.g, MEMORY means that + * table data will be in memory and persistent across Shark sessions). 
+ */ + def createTable( + hiveMetadataDb: Hive, + createTblDesc: CreateTableDesc, + cacheMode: CacheType.CacheType) { + val dbName = hiveMetadataDb.getCurrentDatabase + val tableName = createTblDesc.getTableName + val tblProps = createTblDesc.getTblProps + + if (cacheMode == CacheType.TACHYON) { + // For Tachyon tables (partitioned or not), just create the parent directory. + SharkEnv.tachyonUtil.createDirectory( + MemoryMetadataManager.makeTableKey(dbName, tableName), hivePartitionKeyOpt = None) + } else { + val isHivePartitioned = (createTblDesc.getPartCols.size > 0) + if (isHivePartitioned) { + // Add a new PartitionedMemoryTable entry in the Shark metastore. + // An empty table has a PartitionedMemoryTable entry with no 'hivePartition -> RDD' mappings. + SharkEnv.memoryMetadataManager.createPartitionedMemoryTable( + dbName, tableName, cacheMode, tblProps) + } else { + val memoryTable = SharkEnv.memoryMetadataManager.createMemoryTable( + dbName, tableName, cacheMode) + // An empty table has a MemoryTable table entry with 'tableRDD' referencing an EmptyRDD. + memoryTable.put(new EmptyRDD(SharkEnv.sc)) + } + } + } + + /** + * Updates Shark metastore for an ALTER TABLE ADD PARTITION command. + * + * @param hiveMetadataDb Namespace of the table to update. + * @param addPartitionDesc Hive metadata object that contains fields about the new partition. + */ + def addPartition( + hiveMetadataDb: Hive, + addPartitionDesc: AddPartitionDesc, + cacheMode: CacheType.CacheType) { + val dbName = hiveMetadataDb.getCurrentDatabase() + val tableName = addPartitionDesc.getTableName + + // Find the set of partition column values that specifies the partition being added. + val hiveTable = db.getTable(tableName, false /* throwException */); + val partCols: Seq[String] = hiveTable.getPartCols.map(_.getName) + val partColToValue: JavaMap[String, String] = addPartitionDesc.getPartSpec + // String format for partition key: 'col1=value1/col2=value2/...' + val partKeyStr: String = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partColToValue) + if (cacheMode == CacheType.TACHYON) { + SharkEnv.tachyonUtil.createDirectory( + MemoryMetadataManager.makeTableKey(dbName, tableName), Some(partKeyStr)) + } else { + val partitionedTable = getPartitionedTableWithAssertions(dbName, tableName) + partitionedTable.putPartition(partKeyStr, new EmptyRDD(SharkEnv.sc)) + } + } + + /** + * Updates Shark metastore when dropping a table or partition. + * + * @param hiveMetadataDb Namespace of the table to drop, or the table that a partition belongs to. + * @param dropTableDesc Hive metadata object used for both dropping entire tables + * (i.e., DROP TABLE) and for dropping individual partitions of a table + * (i.e., ALTER TABLE DROP PARTITION). + */ + def dropTableOrPartition( + hiveMetadataDb: Hive, + dropTableDesc: DropTableDesc, + cacheMode: CacheType.CacheType) { + val dbName = hiveMetadataDb.getCurrentDatabase() + val tableName = dropTableDesc.getTableName + val hiveTable = db.getTable(tableName, false /* throwException */); + val partSpecs: JavaList[PartitionSpec] = dropTableDesc.getPartSpecs + val tableKey = MemoryMetadataManager.makeTableKey(dbName, tableName) + + if (partSpecs == null) { + // The command is a true DROP TABLE. 
+ if (cacheMode == CacheType.TACHYON) { + SharkEnv.tachyonUtil.dropTable(tableKey, hivePartitionKeyOpt = None) + } else { + SharkEnv.memoryMetadataManager.removeTable(dbName, tableName) + } + } else { + // The command is an ALTER TABLE DROP PARTITION + // Find the set of partition column values that specifies the partition being dropped. + val partCols: Seq[String] = hiveTable.getPartCols.map(_.getName) + for (partSpec <- partSpecs) { + val partColToValue: JavaMap[String, String] = partSpec.getPartSpecWithoutOperator + // String format for partition key: 'col1=value1/col2=value2/...' + val partKeyStr = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partColToValue) + if (cacheMode == CacheType.TACHYON) { + SharkEnv.tachyonUtil.dropTable(tableKey, Some(partKeyStr)) + } else { + val partitionedTable = getPartitionedTableWithAssertions(dbName, tableName) + getPartitionedTableWithAssertions(dbName, tableName).removePartition(partKeyStr) + } + } + } + } + + /** + * Updates Shark metastore for miscellaneous ALTER TABLE commands. + * + * @param hiveMetadataDb Namespace of the table to update. + * @param alterTableDesc Hive metadata object containing fields needed to handle various table + * update commands, such as ALTER TABLE RENAME TO. + * + */ + def alterTable( + hiveMetadataDb: Hive, + alterTableDesc: AlterTableDesc, + cacheMode: CacheType.CacheType) { + val dbName = hiveMetadataDb.getCurrentDatabase() + alterTableDesc.getOp() match { + case AlterTableDesc.AlterTableTypes.RENAME => { + val oldName = alterTableDesc.getOldName + val newName = alterTableDesc.getNewName + if (cacheMode == CacheType.TACHYON) { + val oldTableKey = MemoryMetadataManager.makeTableKey(dbName, oldName) + val newTableKey = MemoryMetadataManager.makeTableKey(dbName, newName) + SharkEnv.tachyonUtil.renameDirectory(oldTableKey, newTableKey) + } else { + SharkEnv.memoryMetadataManager.renameTable(dbName, oldName, newName) + } + } + case _ => { + // TODO(harvey): Support more ALTER TABLE commands, such as ALTER TABLE PARTITION RENAME TO. + throw new UnsupportedOperationException( + "Shark only requires a Shark DDL task for ALTER TABLE RENAME") + } + } + } + + private def getPartitionedTableWithAssertions( + dbName: String, + tableName: String): PartitionedMemoryTable = { + // Sanity checks: make sure that the table we're modifying exists in the Shark metastore and + // is actually partitioned. 
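Both addPartition and dropTableOrPartition above key their metastore updates on the 'col1=value1/col2=value2/...' string described in the comments. MemoryMetadataManager.makeHivePartitionKeyStr is not part of this diff, so the following is only a plausible sketch of that format, not the actual helper:

object PartitionKeySketch {
  // Illustrative only: builds the 'col1=value1/col2=value2/...' key described above.
  // The real helper lives in MemoryMetadataManager and may differ in detail.
  def makeKeyStr(partCols: Seq[String], partColToValue: Map[String, String]): String =
    partCols.map(col => col + "=" + partColToValue(col)).mkString("/")

  def main(args: Array[String]) {
    val key = makeKeyStr(Seq("dt", "country"), Map("dt" -> "2013-11-01", "country" -> "US"))
    println(key)  // dt=2013-11-01/country=US
  }
}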
+ val tableOpt = SharkEnv.memoryMetadataManager.getTable(dbName, tableName) + assert(tableOpt.isDefined, "Internal Error: table %s doesn't exist in Shark metastore.") + assert(tableOpt.get.isInstanceOf[PartitionedMemoryTable], + "Internal Error: table %s isn't partitioned when it should be.") + return tableOpt.get.asInstanceOf[PartitionedMemoryTable] + } + + override def getType = StageType.DDL + + override def getName = "DDL-SPARK" + + override def localizeMRTmpFilesImpl(ctx: Context) = Unit + +} diff --git a/src/main/scala/shark/execution/SharkExplainTask.scala b/src/main/scala/shark/execution/SharkExplainTask.scala index 2462a67b..10fcd6f3 100755 --- a/src/main/scala/shark/execution/SharkExplainTask.scala +++ b/src/main/scala/shark/execution/SharkExplainTask.scala @@ -18,17 +18,15 @@ package shark.execution import java.io.PrintStream -import java.lang.reflect.Method -import java.util.{Arrays, HashSet, List => JavaList} +import java.util.{List => JavaList} import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.exec.{ExplainTask, Task} import org.apache.hadoop.hive.ql.{Context, DriverContext, QueryPlan} -import org.apache.hadoop.hive.ql.plan.{Explain, ExplainWork} -import org.apache.hadoop.hive.ql.plan.api.StageType +import org.apache.hadoop.hive.ql.exec.{ExplainTask, Task} +import org.apache.hadoop.hive.ql.plan.ExplainWork import org.apache.hadoop.util.StringUtils import shark.LogHelper @@ -50,7 +48,7 @@ class SharkExplainTask extends Task[SharkExplainWork] with java.io.Serializable val hiveExplainTask = new ExplainTask override def execute(driverContext: DriverContext): Int = { - logInfo("Executing " + this.getClass.getName()) + logDebug("Executing " + this.getClass.getName()) hiveExplainTask.setWork(work) try { @@ -89,7 +87,7 @@ class SharkExplainTask extends Task[SharkExplainWork] with java.io.Serializable case e: Exception => { console.printError("Failed with exception " + e.getMessage(), "\n" + StringUtils.stringifyException(e)) - throw(e) + throw e 1 } } diff --git a/src/main/scala/shark/execution/SparkLoadTask.scala b/src/main/scala/shark/execution/SparkLoadTask.scala new file mode 100644 index 00000000..e219fd6d --- /dev/null +++ b/src/main/scala/shark/execution/SparkLoadTask.scala @@ -0,0 +1,448 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package shark.execution + +import java.io.Serializable +import java.nio.ByteBuffer +import java.util.{Properties, Map => JavaMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.{Context, DriverContext} +import org.apache.hadoop.hive.ql.exec.{Task => HiveTask, Utilities} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table => HiveTable} +import org.apache.hadoop.hive.ql.plan.api.StageType +import org.apache.hadoop.hive.serde.Constants; +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat} + +import org.apache.spark.SerializableWritable +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +import shark.{LogHelper, SharkEnv, Utils} +import shark.execution.serialization.{KryoSerializer, JavaSerializer} +import shark.memstore2._ +import shark.tachyon.TachyonTableWriter +import shark.util.HiveUtils + + +/** + * Container for fields needed during SparkLoadTask execution. + * + * @param databaseName Namespace for the table being handled. + * @param tableName Name of the table being handled. + * @param commandType Enum representing the command that will be executed for the target table. See + * SparkLoadWork.CommandTypes for a description of which SQL commands correspond to each type. + * @param cacheMode Cache type that the RDD should be stored in (e.g., Spark heap). + */ +private[shark] +class SparkLoadWork( + val databaseName: String, + val tableName: String, + val commandType: SparkLoadWork.CommandTypes.Type, + val cacheMode: CacheType.CacheType) + extends Serializable { + + // Defined if the command is an INSERT and under these conditions: + // - Table is partitioned, and the partition being updated already exists + // (i.e., `partSpecOpt.isDefined == true`) + // - Table is not partitioned - Hive guarantees that data directories exist for updates on such + // tables. + var pathFilterOpt: Option[PathFilter] = None + + // A collection of partition key specifications for partitions to update. Each key is represented + // by a Map of (partitioning column -> value) pairs. + var partSpecs: Seq[JavaMap[String, String]] = Nil + + def addPartSpec(partSpec: JavaMap[String, String]) { + // Not the most efficient, but this method isn't called very often - either a single partition + // spec is passed for partition update, or all partitions are passed for cache load operation. + partSpecs = partSpecs ++ Seq(partSpec) + } +} + +object SparkLoadWork { + object CommandTypes extends Enumeration { + type Type = Value + + // Type of commands executed by the SparkLoadTask created from a SparkLoadWork. + // Corresponding SQL commands for each enum: + // - NEW_ENTRY: + // CACHE or ALTER TABLE
<table name> SET TBLPROPERTIES('shark.cache' = `true` ... ) + // - INSERT: + // INSERT INTO TABLE <table name> or LOAD DATA INPATH '...' INTO <table name> + // - OVERWRITE: + // INSERT OVERWRITE TABLE <table name> or LOAD DATA INPATH '...' OVERWRITE INTO <table name>
+ val OVERWRITE, INSERT, NEW_ENTRY = Value + } + + /** + * Factory/helper method used in LOAD and INSERT INTO/OVERWRITE analysis. Sets all necessary + * fields in the SparkLoadWork returned. + */ + def apply( + db: Hive, + conf: HiveConf, + hiveTable: HiveTable, + partSpecOpt: Option[JavaMap[String, String]], + isOverwrite: Boolean): SparkLoadWork = { + val commandType = if (isOverwrite) { + SparkLoadWork.CommandTypes.OVERWRITE + } else { + SparkLoadWork.CommandTypes.INSERT + } + val cacheMode = CacheType.fromString(hiveTable.getProperty("shark.cache")) + val sparkLoadWork = new SparkLoadWork( + hiveTable.getDbName, + hiveTable.getTableName, + commandType, + cacheMode) + partSpecOpt.foreach(sparkLoadWork.addPartSpec(_)) + if (commandType == SparkLoadWork.CommandTypes.INSERT) { + if (hiveTable.isPartitioned) { + partSpecOpt.foreach { partSpec => + // None if the partition being updated doesn't exist yet. + val partitionOpt = Option(db.getPartition(hiveTable, partSpec, false /* forceCreate */)) + sparkLoadWork.pathFilterOpt = partitionOpt.map(part => + Utils.createSnapshotFilter(part.getPartitionPath, conf)) + } + } else { + sparkLoadWork.pathFilterOpt = Some(Utils.createSnapshotFilter(hiveTable.getPath, conf)) + } + } + sparkLoadWork + } +} + +/** + * A Hive task to load data from disk into the Shark cache. Handles INSERT INTO/OVERWRITE, + * LOAD INTO/OVERWRITE, CACHE, and CTAS commands. + */ +private[shark] +class SparkLoadTask extends HiveTask[SparkLoadWork] with Serializable with LogHelper { + + override def execute(driveContext: DriverContext): Int = { + logDebug("Executing " + this.getClass.getName) + + // Set the fair scheduler's pool using mapred.fairscheduler.pool if it is defined. + Option(conf.get("mapred.fairscheduler.pool")).foreach { pool => + SharkEnv.sc.setLocalProperty("spark.scheduler.pool", pool) + } + + val databaseName = work.databaseName + val tableName = work.tableName + // Set Spark's job description to be this query. + SharkEnv.sc.setJobGroup( + "shark.job", + s"Updating table $databaseName.$tableName for a(n) ${work.commandType}") + val hiveTable = Hive.get(conf).getTable(databaseName, tableName) + // Use HadoopTableReader to help with table scans. The `conf` passed is reused across HadoopRDD + // instantiations. + val hadoopReader = new HadoopTableReader(Utilities.getTableDesc(hiveTable), conf) + if (hiveTable.isPartitioned) { + loadPartitionedMemoryTable( + hiveTable, + work.partSpecs, + hadoopReader, + work.pathFilterOpt) + } else { + loadMemoryTable( + hiveTable, + hadoopReader, + work.pathFilterOpt) + } + // Success! + 0 + } + + /** + * Creates and materializes the in-memory, columnar RDD for a given input RDD. + * + * @param inputRdd A hadoop RDD, or a union of hadoop RDDs if the table is partitioned. + * @param serDeProps Properties used to initialize local ColumnarSerDe instantiations. This + * contains the output schema of the ColumnarSerDe and used to create its + * output object inspectors. + * @param broadcastedHiveConf Allows for sharing a Hive Configuration broadcast used to create + * the Hadoop `inputRdd`. + * @param inputOI Object inspector used to read rows from `inputRdd`. + * @param hivePartitionKeyOpt A defined Hive partition key if the RDD being loaded is part of a + * Hive-partitioned table. 
+ */ + private def materialize( + inputRdd: RDD[_], + serDeProps: Properties, + broadcastedHiveConf: Broadcast[SerializableWritable[HiveConf]], + inputOI: StructObjectInspector, + tableKey: String, + hivePartitionKeyOpt: Option[String]) = { + val statsAcc = SharkEnv.sc.accumulableCollection(ArrayBuffer[(Int, TablePartitionStats)]()) + val tachyonWriter = if (work.cacheMode == CacheType.TACHYON) { + // Find the number of columns in the table schema using `serDeProps`. + val numColumns = serDeProps.getProperty(Constants.LIST_COLUMNS).split(',').size + // Use an additional row to store metadata (e.g. number of rows in each partition). + SharkEnv.tachyonUtil.createTableWriter(tableKey, hivePartitionKeyOpt, numColumns + 1) + } else { + null + } + val serializedOI = KryoSerializer.serialize(inputOI) + var transformedRdd = inputRdd.mapPartitionsWithIndex { case (partIndex, partIter) => + val serde = new ColumnarSerDe + serde.initialize(broadcastedHiveConf.value.value, serDeProps) + val localInputOI = KryoSerializer.deserialize[ObjectInspector](serializedOI) + var builder: Writable = null + partIter.foreach { row => + builder = serde.serialize(row.asInstanceOf[AnyRef], localInputOI) + } + if (builder == null) { + // Empty partition. + statsAcc += Tuple2(partIndex, new TablePartitionStats(Array.empty, 0)) + Iterator(new TablePartition(0, Array())) + } else { + statsAcc += Tuple2(partIndex, builder.asInstanceOf[TablePartitionBuilder].stats) + Iterator(builder.asInstanceOf[TablePartitionBuilder].build()) + } + } + // Run a job to materialize the RDD. + if (work.cacheMode == CacheType.TACHYON) { + // Put the table in Tachyon. + logInfo("Putting RDD for %s in Tachyon".format(tableKey)) + if (work.commandType == SparkLoadWork.CommandTypes.OVERWRITE && + SharkEnv.tachyonUtil.tableExists(tableKey, hivePartitionKeyOpt)) { + // For INSERT OVERWRITE, delete the old table or Hive partition directory, if it exists. + SharkEnv.tachyonUtil.dropTable(tableKey, hivePartitionKeyOpt) + } + tachyonWriter.createTable(ByteBuffer.allocate(0)) + transformedRdd = transformedRdd.mapPartitionsWithIndex { case(part, iter) => + val partition = iter.next() + partition.toTachyon.zipWithIndex.foreach { case(buf, column) => + tachyonWriter.writeColumnPartition(column, part, buf) + } + Iterator(partition) + } + } else { + transformedRdd.persist(StorageLevel.MEMORY_AND_DISK) + } + transformedRdd.context.runJob( + transformedRdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) + if (work.cacheMode == CacheType.TACHYON) { + tachyonWriter.updateMetadata(ByteBuffer.wrap(JavaSerializer.serialize(statsAcc.value.toMap))) + } + (transformedRdd, statsAcc.value) + } + + /** Returns a MemoryTable for the given Hive table. */ + private def getOrCreateMemoryTable(hiveTable: HiveTable): MemoryTable = { + val databaseName = hiveTable.getDbName + val tableName = hiveTable.getTableName + work.commandType match { + case SparkLoadWork.CommandTypes.NEW_ENTRY => { + // This is a new entry, e.g. we are caching a new table or partition. + // Create a new MemoryTable object and return that. + SharkEnv.memoryMetadataManager.createMemoryTable(databaseName, tableName, work.cacheMode) + } + case _ => { + // This is an existing entry (e.g. we are handling an INSERT or INSERT OVERWRITE). + // Get the MemoryTable object from the Shark metastore. 
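For context on the materialize() method above: it gathers per-partition statistics through a Spark accumulable collection and then forces evaluation with runJob. A standalone sketch of that pattern, assuming a local SparkContext and using plain (partitionIndex, rowCount) pairs in place of TablePartitionStats:

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext

object StatsAccumulatorSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "stats-sketch")
    // Collects one (partitionIndex, rowCount) entry per partition, the way statsAcc above
    // collects (partitionIndex, TablePartitionStats).
    val statsAcc = sc.accumulableCollection(ArrayBuffer[(Int, Long)]())
    val columnar = sc.parallelize(1 to 1000, 4).mapPartitionsWithIndex { (partIndex, iter) =>
      val rows = iter.toArray
      statsAcc += Tuple2(partIndex, rows.length.toLong)
      Iterator(rows)  // one "built partition" per split
    }
    columnar.persist()
    // Run a job so the RDD (and the accumulator updates) are actually materialized.
    columnar.context.runJob(columnar, (iter: Iterator[Array[Int]]) => iter.foreach(_ => ()))
    println(statsAcc.value.sortBy(_._1))  // ArrayBuffer((0,250), (1,250), (2,250), (3,250))
    sc.stop()
  }
}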
+ val tableOpt = SharkEnv.memoryMetadataManager.getTable(databaseName, tableName) + assert(tableOpt.exists(_.isInstanceOf[MemoryTable]), + "Memory table being updated cannot be found in the Shark metastore.") + tableOpt.get.asInstanceOf[MemoryTable] + } + } + } + + /** + * Handles loading data from disk into the Shark cache for non-partitioned tables. + * + * @param hiveTable Hive metadata object representing the target table. + * @param hadoopReader Used to create a HadoopRDD from the table's data directory. + * @param pathFilterOpt Defined for INSERT update operations (e.g., INSERT INTO) and passed to + * hadoopReader#makeRDDForTable() to determine which new files should be read from the table's + * data directory - see the SparkLoadWork#apply() factory method for an example of how a + * path filter is created. + */ + private def loadMemoryTable( + hiveTable: HiveTable, + hadoopReader: HadoopTableReader, + pathFilterOpt: Option[PathFilter]) { + val databaseName = hiveTable.getDbName + val tableName = hiveTable.getTableName + val tableSchema = hiveTable.getSchema + val serDe = hiveTable.getDeserializer + serDe.initialize(conf, tableSchema) + // Scan the Hive table's data directory. + val inputRDD = hadoopReader.makeRDDForTable(hiveTable, serDe.getClass, pathFilterOpt) + // Transform the HadoopRDD to an RDD[TablePartition]. + val (tablePartitionRDD, tableStats) = materialize( + inputRDD, + tableSchema, + hadoopReader.broadcastedHiveConf, + serDe.getObjectInspector.asInstanceOf[StructObjectInspector], + MemoryMetadataManager.makeTableKey(databaseName, tableName), + hivePartitionKeyOpt = None) + if (work.cacheMode != CacheType.TACHYON) { + val memoryTable = getOrCreateMemoryTable(hiveTable) + work.commandType match { + case (SparkLoadWork.CommandTypes.OVERWRITE | SparkLoadWork.CommandTypes.NEW_ENTRY) => + memoryTable.put(tablePartitionRDD, tableStats.toMap) + case SparkLoadWork.CommandTypes.INSERT => { + memoryTable.update(tablePartitionRDD, tableStats) + } + } + } + } + + /** + * Returns the created (for CommandType.NEW_ENTRY) or fetched (for CommandType.INSERT or + * OVERWRITE) PartitionedMemoryTable corresponding to `partSpecs`. + * + * @param hiveTable The Hive Table. + * @param partSpecs A map of (partitioning column -> corresponding value) that uniquely + * identifies the partition being created or updated. + */ + private def getOrCreatePartitionedMemoryTable( + hiveTable: HiveTable, + partSpecs: JavaMap[String, String]): PartitionedMemoryTable = { + val databaseName = hiveTable.getDbName + val tableName = hiveTable.getTableName + work.commandType match { + case SparkLoadWork.CommandTypes.NEW_ENTRY => { + SharkEnv.memoryMetadataManager.createPartitionedMemoryTable( + databaseName, + tableName, + work.cacheMode, + hiveTable.getParameters) + } + case _ => { + SharkEnv.memoryMetadataManager.getTable(databaseName, tableName) match { + case Some(table: PartitionedMemoryTable) => table + case _ => { + val tableOpt = SharkEnv.memoryMetadataManager.getTable(databaseName, tableName) + assert(tableOpt.exists(_.isInstanceOf[PartitionedMemoryTable]), + "Partitioned memory table being updated cannot be found in the Shark metastore.") + tableOpt.get.asInstanceOf[PartitionedMemoryTable] + } + } + } + } + } + + /** + * Handles loading data from disk into the Shark cache for non-partitioned tables. + * + * @param hiveTable Hive metadata object representing the target table. 
+ * @param partSpecs Sequence of partition key specifications that contains either a single key, + * or all of the table's partition keys. This is because only one partition specficiation is + * allowed for each append or overwrite command, and new cache entries (i.e, for a CACHE + * comand) are full table scans. + * @param hadoopReader Used to create a HadoopRDD from each partition's data directory. + * @param pathFilterOpt Defined for INSERT update operations (e.g., INSERT INTO) and passed to + * hadoopReader#makeRDDForTable() to determine which new files should be read from the table + * partition's data directory - see the SparkLoadWork#apply() factory method for an example of + * how a path filter is created. + */ + private def loadPartitionedMemoryTable( + hiveTable: HiveTable, + partSpecs: Seq[JavaMap[String, String]], + hadoopReader: HadoopTableReader, + pathFilterOpt: Option[PathFilter]) { + val databaseName = hiveTable.getDbName + val tableName = hiveTable.getTableName + val partCols = hiveTable.getPartCols.map(_.getName) + + for (partSpec <- partSpecs) { + // Read, materialize, and store a columnar-backed RDD for `partSpec`. + val partitionKey = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partSpec) + val partition = db.getPartition(hiveTable, partSpec, false /* forceCreate */) + val partSerDe = partition.getDeserializer() + val partSchema = partition.getSchema + partSerDe.initialize(conf, partSchema) + // Get a UnionStructObjectInspector that unifies the two StructObjectInspectors for the table + // columns and the partition columns. + val unionOI = HiveUtils.makeUnionOIForPartitionedTable(partSchema, partSerDe) + // Create a HadoopRDD for the file scan. + val inputRDD = hadoopReader.makeRDDForPartitionedTable( + Map(partition -> partSerDe.getClass), pathFilterOpt) + val (tablePartitionRDD, tableStats) = materialize( + inputRDD, + SparkLoadTask.addPartitionInfoToSerDeProps(partCols, partition.getSchema), + hadoopReader.broadcastedHiveConf, + unionOI, + MemoryMetadataManager.makeTableKey(databaseName, tableName), + Some(partitionKey)) + if (work.cacheMode != CacheType.TACHYON) { + // Handle appends or overwrites. + val partitionedTable = getOrCreatePartitionedMemoryTable(hiveTable, partSpec) + if (partitionedTable.containsPartition(partitionKey) && + (work.commandType == SparkLoadWork.CommandTypes.INSERT)) { + partitionedTable.updatePartition(partitionKey, tablePartitionRDD, tableStats) + } else { + partitionedTable.putPartition(partitionKey, tablePartitionRDD, tableStats.toMap) + } + } + } + } + + override def getType = StageType.MAPRED + + override def getName = "MAPRED-LOAD-SPARK" + + override def localizeMRTmpFilesImpl(ctx: Context) = Unit +} + + +object SparkLoadTask { + + /** + * Returns a copy of `baseSerDeProps` with the names and types for the table's partitioning + * columns appended to respective row metadata properties. + */ + private def addPartitionInfoToSerDeProps( + partCols: Seq[String], + baseSerDeProps: Properties): Properties = { + val serDeProps = new Properties(baseSerDeProps) + + // Column names specified by the Constants.LIST_COLUMNS key are delimited by ",". + // E.g., for a table created from + // CREATE TABLE page_views(key INT, val BIGINT), PARTITIONED BY (dt STRING, country STRING), + // `columnNameProperties` will be "key,val". We want to append the "dt, country" partition + // column names to it, and reset the Constants.LIST_COLUMNS entry in the SerDe properties. 
+ var columnNameProperties: String = serDeProps.getProperty(Constants.LIST_COLUMNS) + columnNameProperties += "," + partCols.mkString(",") + serDeProps.setProperty(Constants.LIST_COLUMNS, columnNameProperties) + + // `None` if column types are missing. By default, Hive SerDeParameters initialized by the + // ColumnarSerDe will treat all columns as having string types. + // Column types specified by the Constants.LIST_COLUMN_TYPES key are delimited by ":" + // E.g., for the CREATE TABLE example above, if `columnTypeProperties` is defined, then it + // will be "int:bigint". Partition columns are strings, so "string:string" should be appended. + val columnTypePropertiesOpt = Option(serDeProps.getProperty(Constants.LIST_COLUMN_TYPES)) + columnTypePropertiesOpt.foreach { columnTypeProperties => + serDeProps.setProperty(Constants.LIST_COLUMN_TYPES, + columnTypeProperties + (":" + Constants.STRING_TYPE_NAME * partCols.size)) + } + serDeProps + } +} diff --git a/src/main/scala/shark/execution/SparkTask.scala b/src/main/scala/shark/execution/SparkTask.scala index 32241a47..f878ce0c 100755 --- a/src/main/scala/shark/execution/SparkTask.scala +++ b/src/main/scala/shark/execution/SparkTask.scala @@ -54,7 +54,7 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper { def tableRdd: Option[TableRDD] = _tableRdd override def execute(driverContext: DriverContext): Int = { - logInfo("Executing " + this.getClass.getName) + logDebug("Executing " + this.getClass.getName) val ctx = driverContext.getCtx() @@ -86,17 +86,15 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper { initializeTableScanTableDesc(tableScanOps) - // Initialize the Hive query plan. This gives us all the object inspectors. - initializeAllHiveOperators(terminalOp) - terminalOp.initializeMasterOnAll() // Set Spark's job description to be this query. - SharkEnv.sc.setJobDescription(work.pctx.getContext.getCmd) + SharkEnv.sc.setJobGroup("shark.job", work.pctx.getContext.getCmd) - // Set the fair scheduler's pool. - SharkEnv.sc.setLocalProperty("spark.scheduler.cluster.fair.pool", - conf.get("mapred.fairscheduler.pool")) + // Set the fair scheduler's pool using mapred.fairscheduler.pool if it is defined. + Option(conf.get("mapred.fairscheduler.pool")).foreach { pool => + SharkEnv.sc.setLocalProperty("spark.scheduler.pool", pool) + } val sinkRdd = terminalOp.execute().asInstanceOf[RDD[Any]] @@ -116,6 +114,7 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper { // topToTable maps Hive's TableScanOperator to the Table object. 
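The comments above describe how partition column names and types are appended to the comma-delimited Constants.LIST_COLUMNS and colon-delimited Constants.LIST_COLUMN_TYPES properties. A pure-string sketch of that rewriting (object and method names invented for illustration; the real code operates on a java.util.Properties copy):

object SerDePropsSketch {
  def appendPartitionCols(
      columnNames: String,      // e.g. "key,val"
      columnTypes: String,      // e.g. "int:bigint"
      partCols: Seq[String]): (String, String) = {
    val names = columnNames + "," + partCols.mkString(",")
    // Partition columns are typed as strings. Building the suffix with mkString keeps the
    // ":" separators explicit; note that in Scala `":" + "string" * n` repeats only
    // "string" with no separators, because `*` binds tighter than `+`.
    val types = columnTypes + partCols.map(_ => "string").mkString(":", ":", "")
    (names, types)
  }

  def main(args: Array[String]) {
    // For: CREATE TABLE page_views(key INT, val BIGINT) PARTITIONED BY (dt STRING, country STRING)
    println(appendPartitionCols("key,val", "int:bigint", Seq("dt", "country")))
    // -> (key,val,dt,country,int:bigint:string:string)
  }
}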
val topToTable: JHashMap[HiveTableScanOperator, Table] = work.pctx.getTopToTable() + val emptyPartnArray = new Array[Partition](0) // Add table metadata to TableScanOperators topOps.foreach { op => op.table = topToTable.get(op.hiveOp) @@ -127,7 +126,8 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper { work.pctx.getOpToPartPruner().get(op.hiveOp), work.pctx.getConf(), "", work.pctx.getPrunedPartitions()) - op.parts = ppl.getConfirmedPartns.toArray ++ ppl.getUnknownPartns.toArray + op.parts = ppl.getConfirmedPartns.toArray(emptyPartnArray) ++ + ppl.getUnknownPartns.toArray(emptyPartnArray) val allParts = op.parts ++ ppl.getDeniedPartns.toArray if (allParts.size == 0) { op.firstConfPartDesc = new PartitionDesc(op.tableDesc, null) @@ -138,28 +138,6 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper { } } - def initializeAllHiveOperators(terminalOp: TerminalOperator) { - // Need to guarantee all parents are initialized before the child. - val topOpList = new scala.collection.mutable.MutableList[HiveTopOperator] - val queue = new scala.collection.mutable.Queue[Operator[_]] - queue.enqueue(terminalOp) - - while (!queue.isEmpty) { - val current = queue.dequeue() - current match { - case op: HiveTopOperator => topOpList += op - case _ => Unit - } - queue ++= current.parentOperators - } - - // Run the initialization. This guarantees that upstream operators are - // initialized before downstream ones. - topOpList.reverse.foreach { topOp => - topOp.initializeHiveTopOperator() - } - } - override def getType = StageType.MAPRED override def getName = "MAPRED-SPARK" diff --git a/src/main/scala/shark/execution/TableReader.scala b/src/main/scala/shark/execution/TableReader.scala new file mode 100644 index 00000000..0687c7c8 --- /dev/null +++ b/src/main/scala/shark/execution/TableReader.scala @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.execution + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import org.apache.hadoop.hive.metastore.api.Constants.META_TABLE_PARTITION_COLUMNS +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.plan.TableDesc + +import org.apache.spark.rdd.{EmptyRDD, RDD, UnionRDD} + +import shark.{LogHelper, SharkEnv} +import shark.api.QueryExecutionException +import shark.execution.serialization.JavaSerializer +import shark.memstore2.{MemoryMetadataManager, Table, TablePartition, TablePartitionStats} +import shark.tachyon.TachyonException + + +/** + * A trait for subclasses that handle table scans. In Shark, there is one subclass for each + * type of table storage: HeapTableReader for Shark tables in Spark's block manager, + * TachyonTableReader for tables in Tachyon, and HadoopTableReader for Hive tables in a filesystem. 
+ */ +trait TableReader extends LogHelper { + + type PruningFunctionType = (RDD[_], collection.Map[Int, TablePartitionStats]) => RDD[_] + + def makeRDDForTable( + hiveTable: HiveTable, + pruningFnOpt: Option[PruningFunctionType] = None + ): RDD[_] + + def makeRDDForPartitionedTable( + partitions: Seq[HivePartition], + pruningFnOpt: Option[PruningFunctionType] = None + ): RDD[_] +} + +/** Helper class for scanning tables stored in Tachyon. */ +class TachyonTableReader(@transient _tableDesc: TableDesc) extends TableReader { + + // Split from 'databaseName.tableName' + private val _tableNameSplit = _tableDesc.getTableName.split('.') + private val _databaseName = _tableNameSplit(0) + private val _tableName = _tableNameSplit(1) + + override def makeRDDForTable( + hiveTable: HiveTable, + pruningFnOpt: Option[PruningFunctionType] = None + ): RDD[_] = { + val tableKey = MemoryMetadataManager.makeTableKey(_databaseName, _tableName) + makeRDD(tableKey, hivePartitionKeyOpt = None, pruningFnOpt) + } + + override def makeRDDForPartitionedTable( + partitions: Seq[HivePartition], + pruningFnOpt: Option[PruningFunctionType] = None): RDD[_] = { + val tableKey = MemoryMetadataManager.makeTableKey(_databaseName, _tableName) + val hivePartitionRDDs = partitions.map { hivePartition => + val partDesc = Utilities.getPartitionDesc(hivePartition) + // Get partition field info + val partSpec = partDesc.getPartSpec() + val partProps = partDesc.getProperties() + + val partColsDelimited = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) + // Partitioning columns are delimited by "/" + val partCols = partColsDelimited.trim().split("/").toSeq + // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'. + val partValues = if (partSpec == null) { + Array.fill(partCols.size)(new String) + } else { + partCols.map(col => new String(partSpec.get(col))).toArray + } + val partitionKeyStr = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partSpec) + val hivePartitionRDD = makeRDD(tableKey, Some(partitionKeyStr), pruningFnOpt) + hivePartitionRDD.mapPartitions { iter => + if (iter.hasNext) { + // Map each tuple to a row object + val rowWithPartArr = new Array[Object](2) + iter.map { value => + rowWithPartArr.update(0, value.asInstanceOf[Object]) + rowWithPartArr.update(1, partValues) + rowWithPartArr.asInstanceOf[Object] + } + } else { + Iterator.empty + } + } + } + if (hivePartitionRDDs.size > 0) { + new UnionRDD(hivePartitionRDDs.head.context, hivePartitionRDDs) + } else { + new EmptyRDD[Object](SharkEnv.sc) + } + } + + private def makeRDD( + tableKey: String, + hivePartitionKeyOpt: Option[String], + pruningFnOpt: Option[PruningFunctionType]): RDD[Any] = { + // Check that the table is in Tachyon. 
+ if (!SharkEnv.tachyonUtil.tableExists(tableKey, hivePartitionKeyOpt)) { + throw new TachyonException("Table " + tableKey + " does not exist in Tachyon") + } + val tableRDDsAndStats = SharkEnv.tachyonUtil.createRDD(tableKey, hivePartitionKeyOpt) + val prunedRDDs = if (pruningFnOpt.isDefined) { + val pruningFn = pruningFnOpt.get + tableRDDsAndStats.map(tableRDDWithStats => + pruningFn(tableRDDWithStats._1, tableRDDWithStats._2).asInstanceOf[RDD[Any]]) + } else { + tableRDDsAndStats.map(tableRDDAndStats => tableRDDAndStats._1.asInstanceOf[RDD[Any]]) + } + val unionedRDD = if (prunedRDDs.isEmpty) { + new EmptyRDD[TablePartition](SharkEnv.sc) + } else { + new UnionRDD(SharkEnv.sc, prunedRDDs) + } + unionedRDD.asInstanceOf[RDD[Any]] + } + +} + +/** Helper class for scanning tables stored in Spark's block manager */ +class HeapTableReader(@transient _tableDesc: TableDesc) extends TableReader { + + // Split from 'databaseName.tableName' + private val _tableNameSplit = _tableDesc.getTableName.split('.') + private val _databaseName = _tableNameSplit(0) + private val _tableName = _tableNameSplit(1) + + /** Fetches and optionally prunes the RDD for `_tableName` from the Shark metastore. */ + override def makeRDDForTable( + hiveTable: HiveTable, + pruningFnOpt: Option[PruningFunctionType] = None + ): RDD[_] = { + logInfo("Loading table %s.%s from Spark block manager".format(_databaseName, _tableName)) + val tableOpt = SharkEnv.memoryMetadataManager.getMemoryTable(_databaseName, _tableName) + if (tableOpt.isEmpty) { + throwMissingTableException() + } + + val table = tableOpt.get + val tableRdd = table.getRDD.get + val tableStats = table.getStats.get + // Prune if an applicable function is given. + pruningFnOpt.map(_(tableRdd, tableStats)).getOrElse(tableRdd) + } + + /** + * Fetches an RDD from the Shark metastore for each partition key given. Returns a single, unioned + * RDD representing all of the specified partition keys. + * + * @param partitions A collection of Hive-partition metadata, such as partition columns and + * partition key specifications. + */ + override def makeRDDForPartitionedTable( + partitions: Seq[HivePartition], + pruningFnOpt: Option[PruningFunctionType] = None + ): RDD[_] = { + val hivePartitionRDDs = partitions.map { partition => + val partDesc = Utilities.getPartitionDesc(partition) + // Get partition field info + val partSpec = partDesc.getPartSpec() + val partProps = partDesc.getProperties() + + val partColsDelimited = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) + // Partitioning columns are delimited by "/" + val partCols = partColsDelimited.trim().split("/").toSeq + // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'. 
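Both TachyonTableReader and HeapTableReader ultimately emit each row of a partitioned table as a two-element array holding the deserialized row and the partition column values. A standalone sketch of that pairing, assuming a local SparkContext; it allocates a fresh array per row, whereas the diff reuses one scratch array per partition because rows are consumed immediately downstream:

import org.apache.spark.SparkContext

object RowWithPartitionValuesSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "row-with-part-values")
    val partValues = Array("2013-11-01", "US")  // values for partition columns (dt, country)
    val rows = sc.parallelize(Seq("row1", "row2", "row3"), 2)
    // Pair every row with the partition's column values, mirroring the two-element
    // rowWithPartArr produced by makeRDDForPartitionedTable.
    val rowsWithPart = rows.map(value => Array[AnyRef](value, partValues))
    rowsWithPart.collect().foreach { arr =>
      println(arr(0) + " -> " + arr(1).asInstanceOf[Array[String]].mkString("/"))
    }
    sc.stop()
  }
}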
+ val partValues = if (partSpec == null) { + Array.fill(partCols.size)(new String) + } else { + partCols.map(col => new String(partSpec.get(col))).toArray + } + + val partitionKeyStr = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partSpec) + val hivePartitionedTableOpt = SharkEnv.memoryMetadataManager.getPartitionedTable( + _databaseName, _tableName) + if (hivePartitionedTableOpt.isEmpty) { + throwMissingTableException() + } + val hivePartitionedTable = hivePartitionedTableOpt.get + + val rddAndStatsOpt = hivePartitionedTable.getPartitionAndStats(partitionKeyStr) + if (rddAndStatsOpt.isEmpty) { + throwMissingPartitionException(partitionKeyStr) + } + val (hivePartitionRDD, hivePartitionStats) = (rddAndStatsOpt.get._1, rddAndStatsOpt.get._2) + val prunedPartitionRDD = pruningFnOpt.map(_(hivePartitionRDD, hivePartitionStats)) + .getOrElse(hivePartitionRDD) + prunedPartitionRDD.mapPartitions { iter => + if (iter.hasNext) { + // Map each tuple to a row object + val rowWithPartArr = new Array[Object](2) + iter.map { value => + rowWithPartArr.update(0, value.asInstanceOf[Object]) + rowWithPartArr.update(1, partValues) + rowWithPartArr.asInstanceOf[Object] + } + } else { + Iterator.empty + } + } + } + if (hivePartitionRDDs.size > 0) { + new UnionRDD(hivePartitionRDDs.head.context, hivePartitionRDDs) + } else { + new EmptyRDD[Object](SharkEnv.sc) + } + } + + /** + * Thrown if the table identified by the (_databaseName, _tableName) pair cannot be found in + * the Shark metastore. + */ + private def throwMissingTableException() { + logError("""|Table %s.%s not found in block manager. + |Are you trying to access a cached table from a Shark session other than the one + |in which it was created?""".stripMargin.format(_databaseName, _tableName)) + throw new QueryExecutionException("Cached table not found") + } + + /** + * Thrown if the table partition identified by the (_databaseName, _tableName, partValues) tuple + * cannot be found in the Shark metastore. + */ + private def throwMissingPartitionException(partValues: String) { + logError("""|Partition %s for table %s.%s not found in block manager. 
+ |Are you trying to access a cached table from a Shark session other than the one in + |which it was created?""".stripMargin.format(partValues, _databaseName, _tableName)) + throw new QueryExecutionException("Cached table partition not found") + } +} diff --git a/src/main/scala/shark/execution/TableScanOperator.scala b/src/main/scala/shark/execution/TableScanOperator.scala index 9fda4702..eaba7e9b 100755 --- a/src/main/scala/shark/execution/TableScanOperator.scala +++ b/src/main/scala/shark/execution/TableScanOperator.scala @@ -18,40 +18,49 @@ package shark.execution import java.util.{ArrayList, Arrays} + +import scala.collection.JavaConversions._ import scala.reflect.BeanProperty -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.Constants.META_TABLE_PARTITION_COLUMNS import org.apache.hadoop.hive.ql.exec.{TableScanOperator => HiveTableScanOperator} import org.apache.hadoop.hive.ql.exec.{MapSplitPruning, Utilities} -import org.apache.hadoop.hive.ql.io.HiveInputFormat import org.apache.hadoop.hive.ql.metadata.{Partition, Table} -import org.apache.hadoop.hive.ql.plan.{PlanUtils, PartitionDesc, TableDesc} +import org.apache.hadoop.hive.ql.plan.{PartitionDesc, TableDesc, TableScanDesc} +import org.apache.hadoop.hive.serde.Constants import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.io.Writable -import org.apache.spark.rdd.{PartitionPruningRDD, RDD, UnionRDD} +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} -import shark.{SharkConfVars, SharkEnv, Utils} -import shark.api.QueryExecutionException +import shark.{LogHelper, SharkConfVars, SharkEnv} import shark.execution.optimization.ColumnPruner -import shark.execution.serialization.{XmlSerializer, JavaSerializer} -import shark.memstore2.{CacheType, TablePartition, TablePartitionStats} -import shark.tachyon.TachyonException +import shark.memstore2.CacheType +import shark.memstore2.CacheType._ +import shark.memstore2.{ColumnarSerDe, MemoryMetadataManager} +import shark.memstore2.{TablePartition, TablePartitionStats} +import shark.util.HiveUtils /** * The TableScanOperator is used for scanning any type of Shark or Hive table. */ -class TableScanOperator extends TopOperator[HiveTableScanOperator] with HiveTopOperator { +class TableScanOperator extends TopOperator[TableScanDesc] { + // TODO(harvey): Try to use 'TableDesc' for execution and save 'Table' for analysis/planning. + // Decouple `Table` from TableReader and ColumnPruner. @transient var table: Table = _ + @transient var hiveOp: HiveTableScanOperator = _ + // Metadata for Hive-partitions (i.e if the table was created from PARTITION BY). NULL if this // table isn't Hive-partitioned. Set in SparkTask::initializeTableScanTableDesc(). - @transient var parts: Array[Object] = _ + @transient var parts: Array[Partition] = _ + + // For convenience, a local copy of the HiveConf for this task. + @transient var localHConf: HiveConf = _ // PartitionDescs are used during planning in Hive. This reference to a single PartitionDesc // is used to initialize partition ObjectInspectors. 
@@ -62,283 +71,210 @@ class TableScanOperator extends TopOperator[HiveTableScanOperator] with HiveTopO @BeanProperty var firstConfPartDesc: PartitionDesc = _ @BeanProperty var tableDesc: TableDesc = _ - @BeanProperty var localHconf: HiveConf = _ - /** - * Initialize the hive TableScanOperator. This initialization propagates - * downstream. When all Hive TableScanOperators are initialized, the entire - * Hive query plan operators are initialized. - */ - override def initializeHiveTopOperator() { + // True if table data is stored the Spark heap. + @BeanProperty var isInMemoryTableScan: Boolean = _ - val rowObjectInspector = { - if (parts == null) { - val serializer = tableDesc.getDeserializerClass().newInstance() - serializer.initialize(hconf, tableDesc.getProperties) - serializer.getObjectInspector() - } else { - val partProps = firstConfPartDesc.getProperties() - val tableDeser = firstConfPartDesc.getDeserializerClass().newInstance() - tableDeser.initialize(hconf, partProps) - val partCols = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) - val partNames = new ArrayList[String] - val partObjectInspectors = new ArrayList[ObjectInspector] - partCols.trim().split("/").foreach{ key => - partNames.add(key) - partObjectInspectors.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector) - } + @BeanProperty var cacheMode: CacheType.CacheType = _ - // No need to lock this one (see SharkEnv.objectInspectorLock) because - // this is called on the master only. - val partObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( - partNames, partObjectInspectors) - val oiList = Arrays.asList( - tableDeser.getObjectInspector().asInstanceOf[StructObjectInspector], - partObjectInspector.asInstanceOf[StructObjectInspector]) - // new oi is union of table + partition object inspectors - ObjectInspectorFactory.getUnionStructObjectInspector(oiList) - } - } - setInputObjectInspector(0, rowObjectInspector) - super.initializeHiveTopOperator() + override def initializeOnMaster() { + // Create a local copy of the HiveConf that will be assigned job properties and, for disk reads, + // broadcasted to slaves. + localHConf = new HiveConf(super.hconf) + cacheMode = CacheType.fromString( + tableDesc.getProperties().get("shark.cache").asInstanceOf[String]) + isInMemoryTableScan = SharkEnv.memoryMetadataManager.containsTable( + table.getDbName, table.getTableName) } - override def initializeOnMaster() { - localHconf = super.hconf + override def outputObjectInspector() = { + if (parts == null) { + val serializer = if (isInMemoryTableScan || cacheMode == CacheType.TACHYON) { + new ColumnarSerDe + } else { + tableDesc.getDeserializerClass().newInstance() + } + serializer.initialize(hconf, tableDesc.getProperties) + serializer.getObjectInspector() + } else { + val partProps = firstConfPartDesc.getProperties() + val partSerDe = if (isInMemoryTableScan || cacheMode == CacheType.TACHYON) { + new ColumnarSerDe + } else { + firstConfPartDesc.getDeserializerClass().newInstance() + } + partSerDe.initialize(hconf, partProps) + HiveUtils.makeUnionOIForPartitionedTable(partProps, partSerDe) + } } override def execute(): RDD[_] = { assert(parentOperators.size == 0) - val tableKey: String = tableDesc.getTableName.split('.')(1) + + val tableNameSplit = tableDesc.getTableName.split('.') // Split from 'databaseName.tableName' + val databaseName = tableNameSplit(0) + val tableName = tableNameSplit(1) // There are three places we can load the table from. - // 1. Tachyon table - // 2. Spark heap (block manager) + // 1. 
Spark heap (block manager), accessed through the Shark MemoryMetadataManager + // 2. Tachyon table // 3. Hive table on HDFS (or other Hadoop storage) - - val cacheMode = CacheType.fromString( - tableDesc.getProperties().get("shark.cache").asInstanceOf[String]) - if (cacheMode == CacheType.heap) { - // Table should be in Spark heap (block manager). - val rdd = SharkEnv.memoryMetadataManager.get(tableKey).getOrElse { - logError("""|Table %s not found in block manager. - |Are you trying to access a cached table from a Shark session other than - |the one in which it was created?""".stripMargin.format(tableKey)) - throw(new QueryExecutionException("Cached table not found")) + // TODO(harvey): Pruning Hive-partitioned, cached tables isn't supported yet. + if (isInMemoryTableScan || cacheMode == CacheType.TACHYON) { + if (isInMemoryTableScan) { + assert(cacheMode == CacheType.MEMORY || cacheMode == CacheType.MEMORY_ONLY, + "Table %s.%s is in Shark metastore, but its cacheMode (%s) indicates otherwise". + format(databaseName, tableName, cacheMode)) } - logInfo("Loading table " + tableKey + " from Spark block manager") - createPrunedRdd(tableKey, rdd) - } else if (cacheMode == CacheType.tachyon) { - // Table is in Tachyon. - if (!SharkEnv.tachyonUtil.tableExists(tableKey)) { - throw new TachyonException("Table " + tableKey + " does not exist in Tachyon") + val tableReader = if (cacheMode == CacheType.TACHYON) { + new TachyonTableReader(tableDesc) + } else { + new HeapTableReader(tableDesc) } - logInfo("Loading table " + tableKey + " from Tachyon.") - - var indexToStats: collection.Map[Int, TablePartitionStats] = - SharkEnv.memoryMetadataManager.getStats(tableKey).getOrElse(null) - - if (indexToStats == null) { - val statsByteBuffer = SharkEnv.tachyonUtil.getTableMetadata(tableKey) - indexToStats = JavaSerializer.deserialize[collection.Map[Int, TablePartitionStats]]( - statsByteBuffer.array()) - logInfo("Loading table " + tableKey + " stats from Tachyon.") - SharkEnv.memoryMetadataManager.putStats(tableKey, indexToStats) + if (table.isPartitioned) { + tableReader.makeRDDForPartitionedTable(parts, Some(createPrunedRdd _)) + } else { + tableReader.makeRDDForTable(table, Some(createPrunedRdd _)) } - createPrunedRdd(tableKey, SharkEnv.tachyonUtil.createRDD(tableKey)) } else { // Table is a Hive table on HDFS (or other Hadoop storage). - super.execute() + makeRDDFromHadoop() } } - private def createPrunedRdd(tableKey: String, rdd: RDD[_]): RDD[_] = { - // Stats used for map pruning. - val indexToStats: collection.Map[Int, TablePartitionStats] = - SharkEnv.memoryMetadataManager.getStats(tableKey).get - + private def createPrunedRdd( + rdd: RDD[_], + indexToStats: collection.Map[Int, TablePartitionStats]): RDD[_] = { // Run map pruning if the flag is set, there exists a filter predicate on // the input table and we have statistics on the table. val columnsUsed = new ColumnPruner(this, table).columnsUsed - SharkEnv.tachyonUtil.pushDownColumnPruning(rdd, columnsUsed) - - val prunedRdd: RDD[_] = - if (SharkConfVars.getBoolVar(localHconf, SharkConfVars.MAP_PRUNING) && - childOperators(0).isInstanceOf[FilterOperator] && - indexToStats.size == rdd.partitions.size) { - - val startTime = System.currentTimeMillis - val printPruneDebug = SharkConfVars.getBoolVar( - localHconf, SharkConfVars.MAP_PRUNING_PRINT_DEBUG) - - // Must initialize the condition evaluator in FilterOperator to get the - // udfs and object inspectors set. 
- val filterOp = childOperators(0).asInstanceOf[FilterOperator] - filterOp.initializeOnSlave() - - def prunePartitionFunc(index: Int): Boolean = { - if (printPruneDebug) { - logInfo("\nPartition " + index + "\n" + indexToStats(index)) - } - // Only test for pruning if we have stats on the column. - val partitionStats = indexToStats(index) - if (partitionStats != null && partitionStats.stats != null) { - MapSplitPruning.test(partitionStats, filterOp.conditionEvaluator) - } else { - true - } - } - // Do the pruning. - val prunedRdd = PartitionPruningRDD.create(rdd, prunePartitionFunc) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Map pruning %d partitions into %s partitions took %d ms".format( - rdd.partitions.size, prunedRdd.partitions.size, timeTaken)) - prunedRdd - } else { - rdd + if (!table.isPartitioned && cacheMode == CacheType.TACHYON) { + SharkEnv.tachyonUtil.pushDownColumnPruning(rdd, columnsUsed) + } + + val shouldPrune = SharkConfVars.getBoolVar(localHConf, SharkConfVars.MAP_PRUNING) && + childOperators(0).isInstanceOf[FilterOperator] && + indexToStats.size == rdd.partitions.size + + val prunedRdd: RDD[_] = if (shouldPrune) { + val startTime = System.currentTimeMillis + val printPruneDebug = SharkConfVars.getBoolVar( + localHConf, SharkConfVars.MAP_PRUNING_PRINT_DEBUG) + + // Must initialize the condition evaluator in FilterOperator to get the + // udfs and object inspectors set. + val filterOp = childOperators(0).asInstanceOf[FilterOperator] + filterOp.initializeOnSlave() + + def prunePartitionFunc(index: Int): Boolean = { + if (printPruneDebug) { + logInfo("\nPartition " + index + "\n" + indexToStats(index)) + } + // Only test for pruning if we have stats on the column. + val partitionStats = indexToStats(index) + if (partitionStats != null && partitionStats.stats != null) { + MapSplitPruning.test(partitionStats, filterOp.conditionEvaluator) + } else { + true + } } + // Do the pruning. + val prunedRdd = PartitionPruningRDD.create(rdd, prunePartitionFunc) + val timeTaken = System.currentTimeMillis - startTime + logInfo("Map pruning %d partitions into %s partitions took %d ms".format( + rdd.partitions.size, prunedRdd.partitions.size, timeTaken)) + prunedRdd + } else { + rdd + } + prunedRdd.mapPartitions { iter => if (iter.hasNext) { - val tablePartition = iter.next.asInstanceOf[TablePartition] + val tablePartition1 = iter.next() + val tablePartition = tablePartition1.asInstanceOf[TablePartition] tablePartition.prunedIterator(columnsUsed) - //tablePartition.iterator } else { - Iterator() + Iterator.empty } } } /** - * Create a RDD representing the table (with or without partitions). + * Create an RDD for a table stored in Hadoop. */ - override def preprocessRdd(rdd: RDD[_]): RDD[_] = { + def makeRDDFromHadoop(): RDD[_] = { + // Try to have the InputFormats filter predicates. + TableScanOperator.addFilterExprToConf(localHConf, hiveOp) + + val hadoopReader = new HadoopTableReader(tableDesc, localHConf) if (table.isPartitioned) { - logInfo("Making %d Hive partitions".format(parts.size)) - makePartitionRDD(rdd) + logDebug("Making %d Hive partitions".format(parts.size)) + // The returned RDD contains arrays of size two with the elements as + // (deserialized row, column partition value). 
+ return hadoopReader.makeRDDForPartitionedTable(parts) } else { - val tablePath = table.getPath.toString - val ifc = table.getInputFormatClass - .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] - logInfo("Table input: %s".format(tablePath)) - createHadoopRdd(tablePath, ifc) + // The returned RDD contains deserialized row Objects. + return hadoopReader.makeRDDForTable(table) } } - override def processPartition(index: Int, iter: Iterator[_]): Iterator[_] = { - val deserializer = tableDesc.getDeserializerClass().newInstance() - deserializer.initialize(localHconf, tableDesc.getProperties) - iter.map { value => - value match { - case rowWithPart: Array[Object] => rowWithPart - case v: Writable => deserializer.deserialize(v) - case _ => throw new RuntimeException("Failed to match " + value.toString) - } - } - } + // All RDD processing is done in execute(). + override def processPartition(split: Int, iter: Iterator[_]): Iterator[_] = + throw new UnsupportedOperationException("TableScanOperator.processPartition()") + +} + + +object TableScanOperator extends LogHelper { /** - * Create an RDD for every partition column specified in the query. Note that for on-disk Hive - * tables, a data directory is created for each partition corresponding to keys specified using - * 'PARTITION BY'. + * Add filter expressions and column metadata to the HiveConf. This is meant to be called on the + * master - it's impractical to add filters during slave-local JobConf creation in HadoopRDD, + * since we would have to serialize the HiveTableScanOperator. */ - private def makePartitionRDD[T](rdd: RDD[T]): RDD[_] = { - val partitions = parts - val rdds = new Array[RDD[Any]](partitions.size) - - var i = 0 - partitions.foreach { part => - val partition = part.asInstanceOf[Partition] - val partDesc = Utilities.getPartitionDesc(partition) - val tablePath = partition.getPartitionPath.toString - - val ifc = partition.getInputFormatClass - .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] - val parts = createHadoopRdd(tablePath, ifc) - - val serializedHconf = XmlSerializer.serialize(localHconf, localHconf) - val partRDD = parts.mapPartitions { iter => - val hconf = XmlSerializer.deserialize(serializedHconf).asInstanceOf[HiveConf] - val deserializer = partDesc.getDeserializerClass().newInstance() - deserializer.initialize(hconf, partDesc.getProperties()) - - // Get partition field info - val partSpec = partDesc.getPartSpec() - val partProps = partDesc.getProperties() - - val partCols = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) - // Partitioning keys are delimited by "/" - val partKeys = partCols.trim().split("/") - // 'partValues[i]' contains the value for the partitioning key at 'partKeys[i]'. - val partValues = new ArrayList[String] - partKeys.foreach { key => - if (partSpec == null) { - partValues.add(new String) - } else { - partValues.add(new String(partSpec.get(key))) - } + private def addFilterExprToConf(hiveConf: HiveConf, hiveTableScanOp: HiveTableScanOperator) { + val tableScanDesc = hiveTableScanOp.getConf() + if (tableScanDesc == null) return + + val rowSchema = hiveTableScanOp.getSchema + if (rowSchema != null) { + // Add column names to the HiveConf. 
+ val columnNames = new StringBuilder + for (columnInfo <- rowSchema.getSignature()) { + if (columnNames.length > 0) { + columnNames.append(",") } - - val rowWithPartArr = new Array[Object](2) - // Map each tuple to a row object - iter.map { value => - val deserializedRow = deserializer.deserialize(value) // LazyStruct - rowWithPartArr.update(0, deserializedRow) - rowWithPartArr.update(1, partValues) - rowWithPartArr.asInstanceOf[Object] + columnNames.append(columnInfo.getInternalName()) + } + val columnNamesString = columnNames.toString() + hiveConf.set(Constants.LIST_COLUMNS, columnNamesString) + + // Add column types to the HiveConf. + val columnTypes = new StringBuilder + for (columnInfo <- rowSchema.getSignature()) { + if (columnTypes.length > 0) { + columnTypes.append(",") } + columnTypes.append(columnInfo.getType().getTypeName()) } - rdds(i) = partRDD.asInstanceOf[RDD[Any]] - i += 1 + val columnTypesString = columnTypes.toString() + hiveConf.set(Constants.LIST_COLUMN_TYPES, columnTypesString) } - // Even if we don't use any partitions, we still need an empty RDD - if (rdds.size == 0) { - SharkEnv.sc.makeRDD(Seq[Object]()) - } else { - new UnionRDD(rdds(0).context, rdds) - } - } - private def createHadoopRdd(path: String, ifc: Class[InputFormat[Writable, Writable]]) - : RDD[Writable] = { - val conf = new JobConf(localHconf) - if (tableDesc != null) { - Utilities.copyTableJobPropertiesToConf(tableDesc, conf) - } - new HiveInputFormat() { - def doPushFilters() { - pushFilters(conf, hiveOp) - } - }.doPushFilters() - FileInputFormat.setInputPaths(conf, path) - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) - - // Set s3/s3n credentials. Setting them in conf ensures the settings propagate - // from Spark's master all the way to Spark's slaves. - var s3varsSet = false - val s3vars = Seq("fs.s3n.awsAccessKeyId", "fs.s3n.awsSecretAccessKey", - "fs.s3.awsAccessKeyId", "fs.s3.awsSecretAccessKey").foreach { variableName => - if (localHconf.get(variableName) != null) { - s3varsSet = true - conf.set(variableName, localHconf.get(variableName)) - } - } + // Push down predicate filters. + val filterExprNode = tableScanDesc.getFilterExpr() + if (filterExprNode != null) { + val filterText = filterExprNode.getExprString() + hiveConf.set(TableScanDesc.FILTER_TEXT_CONF_STR, filterText) + logDebug("Filter text: " + filterText) - // If none of the s3 credentials are set in Hive conf, try use the environmental - // variables for credentials. - if (!s3varsSet) { - Utils.setAwsCredentials(conf) + val filterExprNodeSerialized = Utilities.serializeExpression(filterExprNode) + hiveConf.set(TableScanDesc.FILTER_EXPR_CONF_STR, filterExprNodeSerialized) + logDebug("Filter expression: " + filterExprNodeSerialized) } - - // Choose the minimum number of splits. If mapred.map.tasks is set, use that unless - // it is smaller than what Spark suggests. - val minSplits = math.max(localHconf.getInt("mapred.map.tasks", 1), SharkEnv.sc.defaultMinSplits) - val rdd = SharkEnv.sc.hadoopRDD(conf, ifc, classOf[Writable], classOf[Writable], minSplits) - - // Only take the value (skip the key) because Hive works only with values. 
- rdd.map(_._2) } + } diff --git a/src/main/scala/shark/execution/TerminalOperator.scala b/src/main/scala/shark/execution/TerminalOperator.scala index 1a6400d7..7aa8afc8 100755 --- a/src/main/scala/shark/execution/TerminalOperator.scala +++ b/src/main/scala/shark/execution/TerminalOperator.scala @@ -23,6 +23,7 @@ import scala.reflect.BeanProperty import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.exec.{FileSinkOperator => HiveFileSinkOperator} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc /** @@ -31,7 +32,7 @@ import org.apache.hadoop.hive.ql.exec.{FileSinkOperator => HiveFileSinkOperator} * - cache query output * - return query as RDD directly (without materializing it) */ -class TerminalOperator extends UnaryOperator[HiveFileSinkOperator] { +class TerminalOperator extends UnaryOperator[FileSinkDesc] { // Create a local copy of hconf and hiveSinkOp so we can XML serialize it. @BeanProperty var localHiveOp: HiveFileSinkOperator = _ @@ -39,12 +40,12 @@ class TerminalOperator extends UnaryOperator[HiveFileSinkOperator] { @BeanProperty val now = new Date() override def initializeOnMaster() { + super.initializeOnMaster() localHconf = super.hconf // Set parent to null so we won't serialize the entire query plan. - hiveOp.setParentOperators(null) - hiveOp.setChildOperators(null) - hiveOp.setInputObjInspectors(null) - localHiveOp = hiveOp + localHiveOp.setParentOperators(null) + localHiveOp.setChildOperators(null) + localHiveOp.setInputObjInspectors(null) } override def initializeOnSlave() { diff --git a/src/main/scala/shark/execution/UDTFOperator.scala b/src/main/scala/shark/execution/UDTFOperator.scala index db59f9cc..5782f370 100755 --- a/src/main/scala/shark/execution/UDTFOperator.scala +++ b/src/main/scala/shark/execution/UDTFOperator.scala @@ -23,14 +23,14 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.reflect.BeanProperty -import org.apache.hadoop.hive.ql.exec.{UDTFOperator => HiveUDTFOperator} import org.apache.hadoop.hive.ql.plan.UDTFDesc import org.apache.hadoop.hive.ql.udf.generic.Collector -import org.apache.hadoop.hive.serde2.objectinspector.{ ObjectInspector, - StandardStructObjectInspector, StructField, StructObjectInspector } +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.StructField -class UDTFOperator extends UnaryOperator[HiveUDTFOperator] { +class UDTFOperator extends UnaryOperator[UDTFDesc] { @BeanProperty var conf: UDTFDesc = _ @@ -38,9 +38,14 @@ class UDTFOperator extends UnaryOperator[HiveUDTFOperator] { @transient var soi: StandardStructObjectInspector = _ @transient var inputFields: JavaList[_ <: StructField] = _ @transient var collector: UDTFCollector = _ + @transient var outputObjInspector: ObjectInspector = _ override def initializeOnMaster() { - conf = hiveOp.getConf() + super.initializeOnMaster() + + conf = desc + + initializeOnSlave() } override def initializeOnSlave() { @@ -56,9 +61,11 @@ class UDTFOperator extends UnaryOperator[HiveUDTFOperator] { }.toArray objToSendToUDTF = new Array[java.lang.Object](inputFields.size) - val udtfOutputOI = conf.getGenericUDTF().initialize(udtfInputOIs) + outputObjInspector = conf.getGenericUDTF().initialize(udtfInputOIs) } + override def outputObjectInspector() = outputObjInspector + override def processPartition(split: Int, iter: Iterator[_]): Iterator[_] = { iter.flatMap { row => 
explode(row) diff --git a/src/main/scala/shark/execution/UnionOperator.scala b/src/main/scala/shark/execution/UnionOperator.scala index 2e46a004..e332739e 100755 --- a/src/main/scala/shark/execution/UnionOperator.scala +++ b/src/main/scala/shark/execution/UnionOperator.scala @@ -19,15 +19,13 @@ package shark.execution import java.util.{ArrayList, List => JavaList} -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.reflect.BeanProperty -import org.apache.hadoop.hive.ql.exec.{UnionOperator => HiveUnionOperator} +import org.apache.hadoop.hive.ql.plan.UnionDesc import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ReturnObjectInspectorResolver import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector import org.apache.hadoop.hive.serde2.objectinspector.StructField import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector @@ -40,23 +38,29 @@ import shark.execution.serialization.OperatorSerializationWrapper * A union operator. If the incoming data are of different type, the union * operator transforms the incoming data into the same type. */ -class UnionOperator extends NaryOperator[HiveUnionOperator] { +class UnionOperator extends NaryOperator[UnionDesc] { - @transient var parentFields: ArrayBuffer[JavaList[_ <: StructField]] = _ - @transient var parentObjInspectors: ArrayBuffer[StructObjectInspector] = _ + @transient var parentFields: Seq[JavaList[_ <: StructField]] = _ + @transient var parentObjInspectors: Seq[StructObjectInspector] = _ @transient var columnTypeResolvers: Array[ReturnObjectInspectorResolver] = _ + @transient var outputObjInspector: ObjectInspector = _ @BeanProperty var needsTransform: Array[Boolean] = _ @BeanProperty var numParents: Int = _ override def initializeOnMaster() { + super.initializeOnMaster() numParents = parentOperators.size - // Use reflection to get the needsTransform boolean array. - val needsTransformField = hiveOp.getClass.getDeclaredField("needsTransform") - needsTransformField.setAccessible(true) - needsTransform = needsTransformField.get(hiveOp).asInstanceOf[Array[Boolean]] - + // whether we need to do transformation for each parent + var parents = parentOperators.length + var outputOI = outputObjectInspector() + needsTransform = Array.tabulate[Boolean](objectInspectors.length) { i => + // ObjectInspectors created by the ObjectInspectorFactory, + // which take the same ref if equals + objectInspectors(i) != outputOI + } + initializeOnSlave() } @@ -82,18 +86,20 @@ class UnionOperator extends NaryOperator[HiveUnionOperator] { } val outputFieldOIs = columnTypeResolvers.map(_.get()) - val outputObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector( - columnNames, outputFieldOIs.toList) + outputObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector( + columnNames, outputFieldOIs.toList) // whether we need to do transformation for each parent // We reuse needsTransform from Hive because the comparison of object // inspectors are hard once we send object inspectors over the wire. 
needsTransform.zipWithIndex.filter(_._1).foreach { case(transform, p) => - logInfo("Union Operator needs to transform row from parent[%d] from %s to %s".format( - p, objectInspectors(p), outputObjInspector)) + logDebug("Union Operator needs to transform row from parent[%d] from %s to %s".format( + p, objectInspectors(p), outputObjInspector)) } } + override def outputObjectInspector() = outputObjInspector + /** * Override execute. The only thing we need to call is combineMultipleRdds(). */ diff --git a/src/main/scala/shark/execution/optimization/ColumnPruner.scala b/src/main/scala/shark/execution/optimization/ColumnPruner.scala index 4ab62194..38efb328 100644 --- a/src/main/scala/shark/execution/optimization/ColumnPruner.scala +++ b/src/main/scala/shark/execution/optimization/ColumnPruner.scala @@ -1,9 +1,26 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.execution.optimization import java.util.BitSet import java.util.{List => JList} -import scala.collection.JavaConversions.{asScalaBuffer, bufferAsJavaList, collectionAsScalaIterable} +import scala.collection.JavaConversions.{asScalaBuffer, collectionAsScalaIterable} import scala.collection.mutable.{Set, HashSet} import org.apache.hadoop.hive.ql.exec.GroupByPreShuffleOperator @@ -14,16 +31,16 @@ import org.apache.hadoop.hive.ql.plan.{FilterDesc, MapJoinDesc, ReduceSinkDesc} import shark.execution.{FilterOperator, JoinOperator, MapJoinOperator, Operator, ReduceSinkOperator, SelectOperator, TopOperator} -import shark.memstore2.{ColumnarStruct, TablePartitionIterator} class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends Serializable { val columnsUsed: BitSet = { val colsToKeep = computeColumnsToKeep() - val allColumns = tbl.getAllCols().map(x => x.getName()) - var b = new BitSet() - for (i <- Range(0, allColumns.size()) if (colsToKeep.contains(allColumns(i)))) { + // No need to prune partition columns - Hive does that for us. 
+ val allColumns = tbl.getCols().map(x => x.getName()) + val b = new BitSet() + for (i <- Range(0, allColumns.size) if colsToKeep.contains(allColumns(i))) { b.set(i, true) } b @@ -38,11 +55,15 @@ class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends /** * Computes the column names that are referenced in the Query */ - private def computeColumnsToKeep(op: Operator[_], - cols: HashSet[String], parentOp: Operator[_] = null): Unit = { + private def computeColumnsToKeep( + op: Operator[_], + cols: HashSet[String], + parentOp: Operator[_] = null) { + def nullGuard[T](s: JList[T]): Seq[T] = { if (s == null) Seq[T]() else s } + op match { case selOp: SelectOperator => { val cnf:SelectDesc = selOp.getConf @@ -67,7 +88,7 @@ class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends if (cnf != null) { val keyEvals = nullGuard(cnf.getKeyCols) val valEvals = nullGuard(cnf.getValueCols) - val evals = (HashSet() ++ keyEvals ++ valEvals) + val evals = HashSet() ++ keyEvals ++ valEvals cols ++= evals.flatMap(x => nullGuard(x.getCols)) } } @@ -76,7 +97,7 @@ class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends if (cnf != null) { val keyEvals = cnf.getKeys.values val valEvals = cnf.getExprs.values - val evals = (HashSet() ++ keyEvals ++ valEvals) + val evals = HashSet() ++ keyEvals ++ valEvals cols ++= evals.flatMap(x => x).flatMap(x => nullGuard(x.getCols)) } } diff --git a/src/main/scala/shark/execution/package.scala b/src/main/scala/shark/execution/package.scala index e99b4766..f8251c8a 100755 --- a/src/main/scala/shark/execution/package.scala +++ b/src/main/scala/shark/execution/package.scala @@ -17,17 +17,17 @@ package shark +import scala.language.implicitConversions + import shark.execution.serialization.KryoSerializationWrapper import shark.execution.serialization.OperatorSerializationWrapper - package object execution { - type HiveOperator = org.apache.hadoop.hive.ql.exec.Operator[_] + type HiveDesc = java.io.Serializable // XXXDesc in Hive is the subclass of Serializable - implicit def opSerWrapper2op[T <: Operator[_ <: HiveOperator]]( + implicit def opSerWrapper2op[T <: Operator[_ <: HiveDesc]]( wrapper: OperatorSerializationWrapper[T]): T = wrapper.value implicit def kryoWrapper2object[T](wrapper: KryoSerializationWrapper[T]): T = wrapper.value } - diff --git a/src/main/scala/shark/execution/serialization/HiveConfSerializer.scala b/src/main/scala/shark/execution/serialization/HiveConfSerializer.scala deleted file mode 100644 index db612e4c..00000000 --- a/src/main/scala/shark/execution/serialization/HiveConfSerializer.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (C) 2012 The Regents of The University California. - * All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package shark.execution.serialization - -import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream -import java.io.DataInputStream -import java.io.DataOutputStream - -import com.ning.compress.lzf.LZFEncoder -import com.ning.compress.lzf.LZFDecoder - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.io.Text - - -object HiveConfSerializer { - - def serialize(hConf: HiveConf): Array[Byte] = { - val os = new ByteArrayOutputStream - val dos = new DataOutputStream(os) - val auxJars = hConf.getAuxJars() - Text.writeString(dos, if(auxJars == null) "" else auxJars) - hConf.write(dos) - LZFEncoder.encode(os.toByteArray()) - } - - def deserialize(b: Array[Byte]): HiveConf = { - val is = new ByteArrayInputStream(LZFDecoder.decode(b)) - val dis = new DataInputStream(is) - val auxJars = Text.readString(dis) - val conf = new HiveConf - conf.readFields(dis) - if(auxJars.equals("").unary_!) - conf.setAuxJars(auxJars) - conf - } -} diff --git a/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala b/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala index 9589a1e9..2a54fbf3 100644 --- a/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala +++ b/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala @@ -23,8 +23,6 @@ package org.apache.hadoop.hive.serde2.binarysortable import java.io.IOException import java.util.{ArrayList => JArrayList} -import scala.collection.JavaConversions._ - import org.apache.hadoop.hive.serde2.SerDeException import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoUtils} diff --git a/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala b/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala index 5b99ba09..1f5544fb 100644 --- a/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala +++ b/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala @@ -22,8 +22,6 @@ package org.apache.hadoop.hive.serde2.binarysortable import java.util.{List => JList} -import scala.collection.JavaConversions._ - import org.apache.hadoop.hive.serde2.objectinspector.{StructField, StructObjectInspector} diff --git a/src/main/scala/shark/execution/serialization/JavaSerializer.scala b/src/main/scala/shark/execution/serialization/JavaSerializer.scala index a98cb95c..df6ab31d 100644 --- a/src/main/scala/shark/execution/serialization/JavaSerializer.scala +++ b/src/main/scala/shark/execution/serialization/JavaSerializer.scala @@ -19,11 +19,12 @@ package shark.execution.serialization import java.nio.ByteBuffer +import org.apache.spark.SparkEnv import org.apache.spark.serializer.{JavaSerializer => SparkJavaSerializer} object JavaSerializer { - @transient val ser = new SparkJavaSerializer + @transient val ser = new SparkJavaSerializer(SparkEnv.get.conf) def serialize[T](o: T): Array[Byte] = { ser.newInstance().serialize(o).array() diff --git a/src/main/scala/shark/execution/serialization/KryoSerializer.scala b/src/main/scala/shark/execution/serialization/KryoSerializer.scala index c4764979..0532fbcc 100644 --- a/src/main/scala/shark/execution/serialization/KryoSerializer.scala +++ b/src/main/scala/shark/execution/serialization/KryoSerializer.scala @@ -19,8 +19,10 @@ package shark.execution.serialization import java.nio.ByteBuffer +import org.apache.spark.{SparkConf, SparkEnv} import 
org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer} +import shark.SharkContext /** * Java object serialization using Kryo. This is much more efficient, but Kryo @@ -29,7 +31,10 @@ import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer} */ object KryoSerializer { - @transient val ser = new SparkKryoSerializer + @transient lazy val ser: SparkKryoSerializer = { + val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new SparkKryoSerializer(sparkConf) + } def serialize[T](o: T): Array[Byte] = { ser.newInstance().serialize(o).array() diff --git a/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala b/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala index 858ce182..19e383c4 100644 --- a/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala +++ b/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala @@ -17,7 +17,7 @@ package shark.execution.serialization -import shark.execution.HiveOperator +import shark.execution.HiveDesc import shark.execution.Operator @@ -28,7 +28,7 @@ import shark.execution.Operator * * Use OperatorSerializationWrapper(operator) to create a wrapper. */ -class OperatorSerializationWrapper[T <: Operator[_ <: HiveOperator]] +class OperatorSerializationWrapper[T <: Operator[_ <: HiveDesc]] extends Serializable with shark.LogHelper { /** The operator we are going to serialize. */ @@ -69,9 +69,9 @@ class OperatorSerializationWrapper[T <: Operator[_ <: HiveOperator]] object OperatorSerializationWrapper { - def apply[T <: Operator[_ <: HiveOperator]](value: T): OperatorSerializationWrapper[T] = { + def apply[T <: Operator[_ <: HiveDesc]](value: T): OperatorSerializationWrapper[T] = { val wrapper = new OperatorSerializationWrapper[T] wrapper.value = value wrapper } -} \ No newline at end of file +} diff --git a/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala b/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala index b2a2d014..e4eba584 100644 --- a/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala +++ b/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala @@ -22,7 +22,9 @@ import java.nio.ByteBuffer import org.apache.hadoop.io.BytesWritable -import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerInstance, SerializationStream} +import org.apache.spark.SparkConf +import org.apache.spark.serializer.DeserializationStream +import org.apache.spark.serializer.{SerializationStream, Serializer, SerializerInstance} import shark.execution.{ReduceKey, ReduceKeyReduceSide} @@ -47,7 +49,11 @@ import shark.execution.{ReduceKey, ReduceKeyReduceSide} * into a hash table. We want to reduce the size of the hash table. Having the BytesWritable wrapper * would increase the size of the hash table by another 16 bytes per key-value pair. */ -class ShuffleSerializer extends Serializer { +class ShuffleSerializer(conf: SparkConf) extends Serializer { + + // A no-arg constructor since conf is not needed in this serializer. 
+ def this() = this(null) + override def newInstance(): SerializerInstance = new ShuffleSerializerInstance } diff --git a/src/main/scala/shark/execution/serialization/XmlSerializer.scala b/src/main/scala/shark/execution/serialization/XmlSerializer.scala index 4c63efab..a533c812 100644 --- a/src/main/scala/shark/execution/serialization/XmlSerializer.scala +++ b/src/main/scala/shark/execution/serialization/XmlSerializer.scala @@ -17,8 +17,8 @@ package shark.execution.serialization -import java.beans.{XMLDecoder, XMLEncoder, PersistenceDelegate} -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutput, ObjectInput} +import java.beans.{XMLDecoder, XMLEncoder} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.ning.compress.lzf.{LZFEncoder, LZFDecoder} @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities.EnumDelegate import org.apache.hadoop.hive.ql.plan.GroupByDesc import org.apache.hadoop.hive.ql.plan.PlanUtils.ExpressionTypes -import shark.{SharkConfVars, SharkEnvSlave} +import shark.SharkConfVars /** diff --git a/src/main/scala/shark/memstore2/CachePolicy.scala b/src/main/scala/shark/memstore2/CachePolicy.scala new file mode 100644 index 00000000..27e29ff6 --- /dev/null +++ b/src/main/scala/shark/memstore2/CachePolicy.scala @@ -0,0 +1,227 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.memstore2 + +import java.util.concurrent.ConcurrentHashMap +import java.util.LinkedHashMap +import java.util.Map.Entry + +import scala.collection.JavaConversions._ + + +/** + * An general interface for pluggable cache eviction policies in Shark. + * One example of usage is to control persistance levels of RDDs that represent a table's + * Hive-partitions. + */ +trait CachePolicy[K, V] { + + protected var _loadFunc: (K => V) = _ + + protected var _evictionFunc: (K, V) => Unit = _ + + protected var _maxSize: Int = -1 + + def initialize( + strArgs: Array[String], + fallbackMaxSize: Int, + loadFunc: K => V, + evictionFunc: (K, V) => Unit) { + _loadFunc = loadFunc + _evictionFunc = evictionFunc + + // By default, only initialize the `maxSize` from user specifications. + strArgs.size match { + case 0 => _maxSize = fallbackMaxSize + case 1 => _maxSize = strArgs.head.toInt + case _ => + throw new Exception("Accpted format: %s(maxSize: Int)".format(this.getClass.getName)) + } + require(maxSize > 0, "Size given to cache eviction policy must be > 1") + } + + def notifyPut(key: K, value: V): Unit + + def notifyRemove(key: K): Unit + + def notifyGet(key: K): Unit + + def keysOfCachedEntries: Seq[K] + + def maxSize: Int = _maxSize + + // TODO(harvey): Call this in Shark's handling of ALTER TABLE TBLPROPERTIES. 
+ def maxSize_= (newMaxSize: Int) = _maxSize = newMaxSize + + def hitRate: Double + + def evictionCount: Long +} + + +object CachePolicy { + + def instantiateWithUserSpecs[K, V]( + str: String, + fallbackMaxSize: Int, + loadFunc: K => V, + evictionFunc: (K, V) => Unit): CachePolicy[K, V] = { + val firstParenPos = str.indexOf('(') + if (firstParenPos == -1) { + val policy = Class.forName(str).newInstance.asInstanceOf[CachePolicy[K, V]] + policy.initialize(Array.empty[String], fallbackMaxSize, loadFunc, evictionFunc) + return policy + } else { + val classStr = str.slice(0, firstParenPos) + val strArgs = str.substring(firstParenPos + 1, str.lastIndexOf(')')).split(',') + val policy = Class.forName(classStr).newInstance.asInstanceOf[CachePolicy[K, V]] + policy.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc) + return policy + } + } +} + + +/** + * A cache that never evicts entries. + */ +class CacheAllPolicy[K, V] extends CachePolicy[K, V] { + + // Track the entries in the cache, so that keysOfCachedEntries() returns a valid result. + var cache = new ConcurrentHashMap[K, V]() + + override def notifyPut(key: K, value: V) = cache.put(key, value) + + override def notifyRemove(key: K) = cache.remove(key) + + override def notifyGet(key: K) = Unit + + override def keysOfCachedEntries: Seq[K] = cache.keySet.toSeq + + override def hitRate = 1.0 + + override def evictionCount = 0L +} + + +class LRUCachePolicy[K, V] extends LinkedMapBasedPolicy[K, V] { + + override def initialize( + strArgs: Array[String], + fallbackMaxSize: Int, + loadFunc: K => V, + evictionFunc: (K, V) => Unit) { + super.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc) + _cache = new LinkedMapCache(true /* evictUsingAccessOrder */) + } + +} + + +class FIFOCachePolicy[K, V] extends LinkedMapBasedPolicy[K, V] { + + override def initialize( + strArgs: Array[String], + fallbackMaxSize: Int, + loadFunc: K => V, + evictionFunc: (K, V) => Unit) { + super.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc) + _cache = new LinkedMapCache() + } + +} + + +sealed abstract class LinkedMapBasedPolicy[K, V] extends CachePolicy[K, V] { + + class LinkedMapCache(evictUsingAccessOrder: Boolean = false) + extends LinkedHashMap[K, V](maxSize, 0.75F, evictUsingAccessOrder) { + + override def removeEldestEntry(eldest: Entry[K, V]): Boolean = { + val shouldRemove = (size() > maxSize) + if (shouldRemove) { + _evictionFunc(eldest.getKey, eldest.getValue) + _evictionCount += 1 + } + return shouldRemove + } + } + + protected var _cache: LinkedMapCache = _ + protected var _isInitialized = false + protected var _hitCount: Long = 0L + protected var _missCount: Long = 0L + protected var _evictionCount: Long = 0L + + override def initialize( + strArgs: Array[String], + fallbackMaxSize: Int, + loadFunc: K => V, + evictionFunc: (K, V) => Unit) { + super.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc) + _isInitialized = true + } + + override def notifyPut(key: K, value: V): Unit = { + assert(_isInitialized, "Must initialize() %s.".format(this.getClass.getName)) + this.synchronized { + val oldValue = _cache.put(key, value) + if (oldValue != null) { + _evictionFunc(key, oldValue) + _evictionCount += 1 + } + } + } + + override def notifyRemove(key: K): Unit = { + assert(_isInitialized, "Must initialize() %s.".format(this.getClass.getName)) + this.synchronized { _cache.remove(key) } + } + + override def notifyGet(key: K): Unit = { + assert(_isInitialized, "Must initialize() %s.".format(this.getClass.getName)) + this.synchronized 
{ + if (_cache.contains(key)) { + _cache.get(key) + _hitCount += 1L + } else { + val loadedValue = _loadFunc(key) + _cache.put(key, loadedValue) + _missCount += 1L + } + } + } + + override def keysOfCachedEntries: Seq[K] = { + assert(_isInitialized, "Must initialize() LRUCachePolicy.") + this.synchronized { + return _cache.keySet.toSeq + } + } + + override def hitRate: Double = { + this.synchronized { + val requestCount = _missCount + _hitCount + val rate = if (requestCount == 0L) 1.0 else (_hitCount.toDouble / requestCount) + return rate + } + } + + override def evictionCount = _evictionCount + +} diff --git a/src/main/scala/shark/memstore2/CacheType.scala b/src/main/scala/shark/memstore2/CacheType.scala index 13115415..ed1e1735 100644 --- a/src/main/scala/shark/memstore2/CacheType.scala +++ b/src/main/scala/shark/memstore2/CacheType.scala @@ -17,28 +17,50 @@ package shark.memstore2 +import shark.LogHelper -object CacheType extends Enumeration { +/* + * Enumerations and static helper functions for caches supported by Shark. + */ +object CacheType extends Enumeration with LogHelper { + + /* + * The CacheTypes: + * - MEMORY: Stored in memory and on disk (i.e., cache is write-through). Persistent across Shark + * sessions. By default, all such tables are reloaded into memory on restart. + * - MEMORY_ONLY: Stored only in memory and dropped at the end of each Shark session. + * - TACHYON: A distributed storage system that manages an in-memory cache for sharing files and + RDDs across cluster frameworks. + * - NONE: Stored on disk (e.g., HDFS) and managed by Hive. + */ type CacheType = Value - val none, heap, tachyon = Value + val MEMORY, MEMORY_ONLY, TACHYON, NONE = Value - def shouldCache(c: CacheType): Boolean = (c != none) + def shouldCache(c: CacheType): Boolean = (c != NONE) /** Get the cache type object from a string representation. */ def fromString(name: String): CacheType = { - if (name == null || name == "") { - none + if (name == null || name == "" || name.toLowerCase == "false") { + NONE } else if (name.toLowerCase == "true") { - heap + MEMORY } else { try { - withName(name.toLowerCase) + if (name.toUpperCase == "HEAP") { + // Interpret 'HEAP' as 'MEMORY' to ensure backwards compatibility with Shark 0.8.0. + logWarning("The 'HEAP' cache type name is deprecated. Use 'MEMORY' instead.") + MEMORY + } else { + // Try to use Scala's Enumeration::withName() to interpret 'name'. + withName(name.toUpperCase) + } } catch { case e: java.util.NoSuchElementException => throw new InvalidCacheTypeException(name) } } } - class InvalidCacheTypeException(name: String) extends Exception("Invalid cache type " + name) + class InvalidCacheTypeException(name: String) + extends Exception("Invalid string representation of cache type: '%s'".format(name)) } diff --git a/src/main/scala/shark/memstore2/ColumnarSerDe.scala b/src/main/scala/shark/memstore2/ColumnarSerDe.scala index 4c8bef76..79c6f282 100644 --- a/src/main/scala/shark/memstore2/ColumnarSerDe.scala +++ b/src/main/scala/shark/memstore2/ColumnarSerDe.scala @@ -51,7 +51,7 @@ class ColumnarSerDe extends SerDe with LogHelper { objectInspector = ColumnarStructObjectInspector(serDeParams) // This null check is needed because Hive's SemanticAnalyzer.genFileSinkPlan() creates - // an instance of the table's StructObjectInspector by creating an instance SerDe, which + // an instance of the table's StructObjectInspector by creating an instance of SerDe, which // it initializes by passing a 'null' argument for 'conf'. 
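// --- Editorial example (not part of this patch) -------------------------------------------
// A minimal sketch of the string-to-CacheType mapping implemented by CacheType.fromString()
// in the hunk above, including the legacy "true"/"false"/"heap" values kept for Shark 0.8.0
// compatibility. The enclosing object name and demo() method are illustrative only.
object CacheTypeExamples {
  import shark.memstore2.CacheType
  import shark.memstore2.CacheType.InvalidCacheTypeException

  def demo() {
    assert(CacheType.fromString(null) == CacheType.NONE)                 // unset => not cached
    assert(CacheType.fromString("false") == CacheType.NONE)
    assert(CacheType.fromString("true") == CacheType.MEMORY)             // legacy 'shark.cache' = 'true'
    assert(CacheType.fromString("heap") == CacheType.MEMORY)             // deprecated alias; logs a warning
    assert(CacheType.fromString("memory_only") == CacheType.MEMORY_ONLY)
    assert(CacheType.fromString("tachyon") == CacheType.TACHYON)
    // Unrecognized names fail loudly instead of being silently ignored.
    try { CacheType.fromString("bogus") } catch { case _: InvalidCacheTypeException => }
  }
}
// -------------------------------------------------------------------------------------------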
if (conf != null) { var partitionSize = { diff --git a/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala b/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala index 02f799fe..67a99612 100644 --- a/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala +++ b/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala @@ -27,8 +27,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo -import shark.{SharkConfVars} - class ColumnarStructObjectInspector(fields: JList[StructField]) extends StructObjectInspector { @@ -60,7 +58,8 @@ object ColumnarStructObjectInspector { for (i <- 0 until columnNames.size) { val typeInfo = columnTypes.get(i) val fieldOI = typeInfo.getCategory match { - case Category.PRIMITIVE => PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( + case Category.PRIMITIVE => + PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( typeInfo.asInstanceOf[PrimitiveTypeInfo].getPrimitiveCategory) case _ => LazyFactory.createLazyObjectInspector( typeInfo, serDeParams.getSeparators(), 1, serDeParams.getNullSequence(), diff --git a/src/main/scala/shark/memstore2/LazySimpleSerDeWrapper.scala b/src/main/scala/shark/memstore2/LazySimpleSerDeWrapper.scala new file mode 100644 index 00000000..2211d557 --- /dev/null +++ b/src/main/scala/shark/memstore2/LazySimpleSerDeWrapper.scala @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package shark.memstore2 + +import java.util.{List => JList, Properties} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.serde2.{SerDe, SerDeStats} +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.io.Writable + + +class LazySimpleSerDeWrapper extends SerDe { + + val _lazySimpleSerDe = new LazySimpleSerDe() + + override def initialize(conf: Configuration, tbl: Properties) { + _lazySimpleSerDe.initialize(conf, tbl) + } + + override def deserialize(blob: Writable): Object = _lazySimpleSerDe.deserialize(blob) + + override def getSerDeStats(): SerDeStats = _lazySimpleSerDe.getSerDeStats() + + override def getObjectInspector: ObjectInspector = _lazySimpleSerDe.getObjectInspector + + override def getSerializedClass: Class[_ <: Writable] = _lazySimpleSerDe.getSerializedClass + + override def serialize(obj: Object, objInspector: ObjectInspector): Writable = { + _lazySimpleSerDe.serialize(obj, objInspector) + } + +} diff --git a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala index c180dd40..9d5ce7ab 100755 --- a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala +++ b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala @@ -17,104 +17,209 @@ package shark.memstore2 +import java.util.{HashMap=> JavaHashMap, Map => JavaMap} import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConversions._ -import scala.collection.mutable.ConcurrentMap +import scala.collection.concurrent + +import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.spark.rdd.{RDD, UnionRDD} -import org.apache.spark.storage.StorageLevel -import shark.SharkConfVars +import shark.{LogHelper, SharkEnv} +import shark.execution.RDDUtils +import shark.util.HiveUtils -class MemoryMetadataManager { +class MemoryMetadataManager extends LogHelper { - private val _keyToRdd: ConcurrentMap[String, RDD[_]] = - new ConcurrentHashMap[String, RDD[_]]() + // Set of tables, from databaseName.tableName to Table object. 
+ private val _tables: concurrent.Map[String, Table] = + new ConcurrentHashMap[String, Table]() - private val _keyToStats: ConcurrentMap[String, collection.Map[Int, TablePartitionStats]] = - new ConcurrentHashMap[String, collection.Map[Int, TablePartitionStats]] + def isHivePartitioned(databaseName: String, tableName: String): Boolean = { + val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName) + _tables.get(tableKey) match { + case Some(table) => table.isInstanceOf[PartitionedMemoryTable] + case None => false + } + } - def contains(key: String) = _keyToRdd.contains(key.toLowerCase) + def containsTable(databaseName: String, tableName: String): Boolean = { + _tables.contains(MemoryMetadataManager.makeTableKey(databaseName, tableName)) + } - def put(key: String, rdd: RDD[_]) { - _keyToRdd(key.toLowerCase) = rdd + def createMemoryTable( + databaseName: String, + tableName: String, + cacheMode: CacheType.CacheType): MemoryTable = { + val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName) + val newTable = new MemoryTable(databaseName, tableName, cacheMode) + _tables.put(tableKey, newTable) + newTable } - def get(key: String): Option[RDD[_]] = _keyToRdd.get(key.toLowerCase) + def createPartitionedMemoryTable( + databaseName: String, + tableName: String, + cacheMode: CacheType.CacheType, + tblProps: JavaMap[String, String] + ): PartitionedMemoryTable = { + val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName) + val newTable = new PartitionedMemoryTable(databaseName, tableName, cacheMode) + // Determine the cache policy to use and read any user-specified cache settings. + val cachePolicyStr = tblProps.getOrElse(SharkTblProperties.CACHE_POLICY.varname, + SharkTblProperties.CACHE_POLICY.defaultVal) + val maxCacheSize = tblProps.getOrElse(SharkTblProperties.MAX_PARTITION_CACHE_SIZE.varname, + SharkTblProperties.MAX_PARTITION_CACHE_SIZE.defaultVal).toInt + newTable.setPartitionCachePolicy(cachePolicyStr, maxCacheSize) + + _tables.put(tableKey, newTable) + newTable + } - def putStats(key: String, stats: collection.Map[Int, TablePartitionStats]) { - _keyToStats.put(key.toLowerCase, stats) + def getTable(databaseName: String, tableName: String): Option[Table] = { + _tables.get(MemoryMetadataManager.makeTableKey(databaseName, tableName)) } - def getStats(key: String): Option[collection.Map[Int, TablePartitionStats]] = { - _keyToStats.get(key.toLowerCase) + def getMemoryTable(databaseName: String, tableName: String): Option[MemoryTable] = { + val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName) + val tableOpt = _tables.get(tableKey) + if (tableOpt.isDefined) { + assert(tableOpt.get.isInstanceOf[MemoryTable], + "getMemoryTable() called for a partitioned table.") + } + tableOpt.asInstanceOf[Option[MemoryTable]] } - /** - * Find all keys that are strings. Used to drop tables after exiting. 
- */ - def getAllKeyStrings(): Seq[String] = { - _keyToRdd.keys.collect { case k: String => k } toSeq + def getPartitionedTable( + databaseName: String, + tableName: String): Option[PartitionedMemoryTable] = { + val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName) + val tableOpt = _tables.get(tableKey) + if (tableOpt.isDefined) { + assert(tableOpt.get.isInstanceOf[PartitionedMemoryTable], + "getPartitionedTable() called for a non-partitioned table.") + } + tableOpt.asInstanceOf[Option[PartitionedMemoryTable]] + } + + def renameTable(databaseName: String, oldName: String, newName: String) { + if (containsTable(databaseName, oldName)) { + val oldTableKey = MemoryMetadataManager.makeTableKey(databaseName, oldName) + val newTableKey = MemoryMetadataManager.makeTableKey(databaseName, newName) + + val tableValueEntry = _tables.remove(oldTableKey).get + tableValueEntry.tableName = newTableKey + + _tables.put(newTableKey, tableValueEntry) + } } /** - * Used to drop an RDD from the Spark in-memory cache and/or disk. All metadata - * (e.g. entry in '_keyToStats') about the RDD that's tracked by Shark is deleted as well. + * Used to drop a table from Spark in-memory cache and/or disk. All metadata is deleted as well. + * + * Note that this is always used in conjunction with a dropTableFromMemory() for handling + *'shark.cache' property changes in an ALTER TABLE command, or to finish off a DROP TABLE command + * after the table has been deleted from the Hive metastore. * - * @param key Used to fetch the an RDD value from '_keyToRDD'. - * @return Option::isEmpty() is true if there is no RDD value corresponding to 'key' in - * '_keyToRDD'. Otherwise, returns a reference to the RDD that was unpersist()'ed. + * @return Option::isEmpty() is true of there is no MemoryTable (and RDD) corresponding to 'key' + * in _keyToMemoryTable. For tables that are Hive-partitioned, the RDD returned will be a + * UnionRDD comprising RDDs that back the table's Hive-partitions. */ - def unpersist(key: String): Option[RDD[_]] = { - def unpersistRDD(rdd: RDD[_]): Unit = { - rdd match { - case u: UnionRDD[_] => { - // Recursively unpersist() all RDDs that compose the UnionRDD. - u.unpersist() - u.rdds.foreach { - r => unpersistRDD(r) - } + def removeTable( + databaseName: String, + tableName: String): Option[RDD[_]] = { + val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName) + val tableValueOpt: Option[Table] = _tables.remove(tableKey) + tableValueOpt.flatMap(tableValue => MemoryMetadataManager.unpersistRDDsForTable(tableValue)) + } + + def shutdown() { + val db = Hive.get() + for (table <- _tables.values) { + table.cacheMode match { + case CacheType.MEMORY => { + dropTableFromMemory(db, table.databaseName, table.tableName) + } + case CacheType.MEMORY_ONLY => HiveUtils.dropTableInHive(table.tableName, db.getConf) + case _ => { + // No need to handle Hive or Tachyon tables, which are persistent and managed by their + // respective systems. + Unit } - case r => r.unpersist() } } - // Remove RDD's entry from Shark metadata. This also fetches a reference to the RDD object - // corresponding to the argument for 'key'. - val rddValue = _keyToRdd.remove(key.toLowerCase()) - _keyToStats.remove(key) - // Unpersist the RDD using the nested helper fn above. - rddValue match { - case Some(rdd) => unpersistRDD(rdd) - case None => Unit + } + + /** + * Drops a table from the Shark cache. However, Shark properties needed for table recovery + * (see TableRecovery#reloadRdds()) won't be removed. 
+ * After this method completes, the table can still be scanned from disk. + */ + def dropTableFromMemory( + db: Hive, + databaseName: String, + tableName: String) { + getTable(databaseName, tableName).foreach { sharkTable => + db.setCurrentDatabase(databaseName) + val hiveTable = db.getTable(databaseName, tableName) + // Refresh the Hive `db`. + db.alterTable(tableName, hiveTable) + // Unpersist the table's RDDs from memory. + removeTable(databaseName, tableName) } - rddValue } } object MemoryMetadataManager { - /** Return a StorageLevel corresponding to its String name. */ - def getStorageLevelFromString(s: String): StorageLevel = { - if (s == null || s == "") { - getStorageLevelFromString(SharkConfVars.STORAGE_LEVEL.defaultVal) - } else { - s.toUpperCase match { - case "NONE" => StorageLevel.NONE - case "DISK_ONLY" => StorageLevel.DISK_ONLY - case "DISK_ONLY_2" => StorageLevel.DISK_ONLY_2 - case "MEMORY_ONLY" => StorageLevel.MEMORY_ONLY - case "MEMORY_ONLY_2" => StorageLevel.MEMORY_ONLY_2 - case "MEMORY_ONLY_SER" => StorageLevel.MEMORY_ONLY_SER - case "MEMORY_ONLY_SER_2" => StorageLevel.MEMORY_ONLY_SER_2 - case "MEMORY_AND_DISK" => StorageLevel.MEMORY_AND_DISK - case "MEMORY_AND_DISK_2" => StorageLevel.MEMORY_AND_DISK_2 - case "MEMORY_AND_DISK_SER" => StorageLevel.MEMORY_AND_DISK_SER - case "MEMORY_AND_DISK_SER_2" => StorageLevel.MEMORY_AND_DISK_SER_2 - case _ => throw new IllegalArgumentException("Unrecognized storage level: " + s) + def unpersistRDDsForTable(table: Table): Option[RDD[_]] = { + table match { + case partitionedTable: PartitionedMemoryTable => { + // unpersist() all RDDs for all Hive-partitions. + val unpersistedRDDs = partitionedTable.keyToPartitions.values.map(rdd => + RDDUtils.unpersistRDD(rdd)).asInstanceOf[Seq[RDD[Any]]] + if (unpersistedRDDs.size > 0) { + val unionedRDD = new UnionRDD(unpersistedRDDs.head.context, unpersistedRDDs) + Some(unionedRDD) + } else { + None + } } + case memoryTable: MemoryTable => Some(RDDUtils.unpersistRDD(memoryTable.getRDD.get)) + } + } + + // Returns a key of the form "databaseName.tableName" that uniquely identifies a Shark table. + // For example, it's used to track a table's RDDs in MemoryMetadataManager and table paths in the + // Tachyon table warehouse. + def makeTableKey(databaseName: String, tableName: String): String = { + (databaseName + '.' + tableName).toLowerCase + } + + /** + * Return a representation of the partition key in the string format: + * 'col1=value1/col2=value2/.../colN=valueN' + */ + def makeHivePartitionKeyStr( + partitionCols: Seq[String], + partColToValue: JavaMap[String, String]): String = { + partitionCols.map(col => "%s=%s".format(col, partColToValue(col))).mkString("/") + } + + /** + * Returns a (partition column name -> value) mapping by parsing a `keyStr` of the format + * 'col1=value1/col2=value2/.../colN=valueN', created by makeHivePartitionKeyStr() above. + */ + def parseHivePartitionKeyStr(keyStr: String): JavaMap[String, String] = { + val partitionSpec = new JavaHashMap[String, String]() + for (pair <- keyStr.split("/")) { + val pairSplit = pair.split("=") + partitionSpec.put(pairSplit(0), pairSplit(1)) } + partitionSpec } } diff --git a/src/main/scala/shark/memstore2/MemoryTable.scala b/src/main/scala/shark/memstore2/MemoryTable.scala new file mode 100644 index 00000000..1a971d4c --- /dev/null +++ b/src/main/scala/shark/memstore2/MemoryTable.scala @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.memstore2 + +import org.apache.spark.rdd.RDD + +import scala.collection.mutable.{Buffer, HashMap} + +import shark.execution.RDDUtils + + +/** + * A metadata container for a table in Shark that's backed by an RDD. + */ +private[shark] class MemoryTable( + databaseName: String, + tableName: String, + cacheMode: CacheType.CacheType) + extends Table(databaseName, tableName, cacheMode) { + + private var _rddValueOpt: Option[RDDValue] = None + + /** + * Sets the RDD and stats fields the `_rddValueOpt`. Used for INSERT/LOAD OVERWRITE. + * @param newRDD The table's data. + * @param newStats Stats for each TablePartition in `newRDD`. + * @return The previous (RDD, stats) pair for this table. + */ + def put( + newRDD: RDD[TablePartition], + newStats: collection.Map[Int, TablePartitionStats] = new HashMap[Int, TablePartitionStats]() + ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { + val prevRDDAndStatsOpt = _rddValueOpt.map(_.toTuple) + if (_rddValueOpt.isDefined) { + _rddValueOpt.foreach { rddValue => + rddValue.rdd = newRDD + rddValue.stats = newStats + } + } else { + _rddValueOpt = Some(new RDDValue(newRDD, newStats)) + } + prevRDDAndStatsOpt + } + + /** + * Used for append operations, such as INSERT and LOAD INTO. + * + * @param newRDD Data to append to the table. + * @param newStats Stats for each TablePartition in `newRDD`. + * @return The previous (RDD, stats) pair for this table. + */ + def update( + newRDD: RDD[TablePartition], + newStats: Buffer[(Int, TablePartitionStats)] + ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { + val prevRDDAndStatsOpt = _rddValueOpt.map(_.toTuple) + if (_rddValueOpt.isDefined) { + val (prevRDD, prevStats) = (prevRDDAndStatsOpt.get._1, prevRDDAndStatsOpt.get._2) + val updatedRDDValue = _rddValueOpt.get + updatedRDDValue.rdd = RDDUtils.unionAndFlatten(prevRDD, newRDD) + updatedRDDValue.stats = Table.mergeStats(newStats, prevStats).toMap + } else { + put(newRDD, newStats.toMap) + } + prevRDDAndStatsOpt + } + + def getRDD = _rddValueOpt.map(_.rdd) + + def getStats = _rddValueOpt.map(_.stats) + +} diff --git a/src/main/scala/shark/memstore2/PartitionedMemoryTable.scala b/src/main/scala/shark/memstore2/PartitionedMemoryTable.scala new file mode 100644 index 00000000..b6bd8ae6 --- /dev/null +++ b/src/main/scala/shark/memstore2/PartitionedMemoryTable.scala @@ -0,0 +1,151 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
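// A hedged usage sketch for MemoryTable (`overwriteRDD` and `appendRDD` are assumed
// RDD[TablePartition] values produced by the INSERT code path, not part of the patch):
// put() replaces the cached contents, while update() unions and re-bases the stats.
import scala.collection.mutable.ArrayBuffer

val memTable = new MemoryTable("default", "users", CacheType.MEMORY)

// INSERT OVERWRITE / LOAD OVERWRITE: drop the old (RDD, stats) pair and install the new one.
memTable.put(overwriteRDD)

// INSERT INTO / LOAD INTO: union with the existing RDD and merge per-partition stats.
memTable.update(appendRDD, ArrayBuffer.empty[(Int, TablePartitionStats)])

memTable.getRDD    // Some(unioned RDD)
memTable.getStats  // Some(merged stats map)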
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.memstore2 + +import java.util.concurrent.{ConcurrentHashMap => ConcurrentJavaHashMap} + +import scala.collection.JavaConversions._ +import scala.collection.concurrent +import scala.collection.mutable.{Buffer, HashMap} + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +import shark.execution.RDDUtils + + +/** + * A metadata container for partitioned Shark table backed by RDDs. + * + * Note that a Hive-partition of a table is different from an RDD partition. Each Hive-partition + * is stored as a subdirectory of the table subdirectory in the warehouse directory + * (e.g. '/user/hive/warehouse'). So, every Hive-Partition is loaded into Shark as an RDD. + */ +private[shark] +class PartitionedMemoryTable( + databaseName: String, + tableName: String, + cacheMode: CacheType.CacheType) + extends Table(databaseName, tableName, cacheMode) { + + // A map from the Hive-partition key to the RDD that contains contents of that partition. + // The conventional string format for the partition key, 'col1=value1/col2=value2/...', can be + // computed using MemoryMetadataManager#makeHivePartitionKeyStr(). + private val _keyToPartitions: concurrent.Map[String, RDDValue] = + new ConcurrentJavaHashMap[String, RDDValue]() + + // The eviction policy for this table's cached Hive-partitions. An example of how this + // can be set from the CLI: + // `TBLPROPERTIES("shark.partition.cachePolicy", "LRUCachePolicy")`. + // If 'None', then all partitions will be put in memory. + // + // Since RDDValue is mutable, entries maintained by a CachePolicy's underlying data structure, + // such as the LinkedHashMap for LRUCachePolicy, can be updated without causing an eviction. + // The value entires for a single key in + // `_keyToPartitions` and `_cachePolicy` will reference the same RDDValue object. 
+ private var _cachePolicy: CachePolicy[String, RDDValue] = _ + + def containsPartition(partitionKey: String): Boolean = _keyToPartitions.contains(partitionKey) + + def getPartition(partitionKey: String): Option[RDD[TablePartition]] = { + getPartitionAndStats(partitionKey).map(_._1) + } + + def getStats(partitionKey: String): Option[collection.Map[Int, TablePartitionStats]] = { + getPartitionAndStats(partitionKey).map(_._2) + } + + def getPartitionAndStats( + partitionKey: String + ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { + val rddValueOpt: Option[RDDValue] = _keyToPartitions.get(partitionKey) + if (rddValueOpt.isDefined) _cachePolicy.notifyGet(partitionKey) + rddValueOpt.map(_.toTuple) + } + + def putPartition( + partitionKey: String, + newRDD: RDD[TablePartition], + newStats: collection.Map[Int, TablePartitionStats] = new HashMap[Int, TablePartitionStats]() + ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { + val rddValueOpt = _keyToPartitions.get(partitionKey) + val prevRDDAndStats = rddValueOpt.map(_.toTuple) + val newRDDValue = new RDDValue(newRDD, newStats) + _keyToPartitions.put(partitionKey, newRDDValue) + _cachePolicy.notifyPut(partitionKey, newRDDValue) + prevRDDAndStats + } + + def updatePartition( + partitionKey: String, + newRDD: RDD[TablePartition], + newStats: Buffer[(Int, TablePartitionStats)] + ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { + val prevRDDAndStatsOpt = getPartitionAndStats(partitionKey) + if (prevRDDAndStatsOpt.isDefined) { + val (prevRDD, prevStats) = (prevRDDAndStatsOpt.get._1, prevRDDAndStatsOpt.get._2) + // This is an update of an old value, so update the RDDValue's `rdd` entry. + // Don't notify the `_cachePolicy`. Assumes that getPartition() has already been called to + // obtain the value of the previous RDD. + // An RDD update refers to the RDD created from an INSERT. + val updatedRDDValue = _keyToPartitions.get(partitionKey).get + updatedRDDValue.rdd = RDDUtils.unionAndFlatten(prevRDD, newRDD) + updatedRDDValue.stats = Table.mergeStats(newStats, prevStats).toMap + } else { + // No previous RDDValue entry currently exists for `partitionKey`, so add one. + putPartition(partitionKey, newRDD, newStats.toMap) + } + prevRDDAndStatsOpt + } + + def removePartition( + partitionKey: String + ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { + val rddRemoved = _keyToPartitions.remove(partitionKey) + if (rddRemoved.isDefined) { + _cachePolicy.notifyRemove(partitionKey) + } + rddRemoved.map(_.toTuple) + } + + /** Returns an immutable view of (partition key -> RDD) mappings to external callers */ + def keyToPartitions: collection.immutable.Map[String, RDD[TablePartition]] = { + _keyToPartitions.mapValues(_.rdd).toMap + } + + def setPartitionCachePolicy(cachePolicyStr: String, fallbackMaxSize: Int) { + // The loadFunc will upgrade the persistence level of the RDD to the preferred storage level. + val loadFunc: String => RDDValue = (partitionKey: String) => { + val rddValue = _keyToPartitions.get(partitionKey).get + if (cacheMode == CacheType.MEMORY) { + rddValue.rdd.persist(StorageLevel.MEMORY_AND_DISK) + } + rddValue + } + // The evictionFunc will unpersist the RDD. 
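// A hedged sketch of the partition-level API above (`partitionRDD` is an assumed
// RDD[TablePartition]). The cache policy is set before any partition is put, since
// putPartition() notifies it.
val partTable = new PartitionedMemoryTable("default", "logs", CacheType.MEMORY)
partTable.setPartitionCachePolicy("shark.memstore2.CacheAllPolicy", fallbackMaxSize = 10)

// Keys use the format produced by MemoryMetadataManager.makeHivePartitionKeyStr().
val key = "ds=2013-11-01"
partTable.putPartition(key, partitionRDD)
partTable.getPartition(key)      // Some(partitionRDD); also counts as an access for the policy
partTable.containsPartition(key) // true
partTable.removePartition(key)   // returns the evicted (RDD, stats) pair, if any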
+ val evictionFunc: (String, RDDValue) => Unit = (partitionKey, rddValue) => { + RDDUtils.unpersistRDD(rddValue.rdd) + } + val newPolicy = CachePolicy.instantiateWithUserSpecs[String, RDDValue]( + cachePolicyStr, fallbackMaxSize, loadFunc, evictionFunc) + _cachePolicy = newPolicy + } + + def cachePolicy: CachePolicy[String, RDDValue] = _cachePolicy + +} diff --git a/src/main/scala/shark/memstore2/SharkTblProperties.scala b/src/main/scala/shark/memstore2/SharkTblProperties.scala new file mode 100644 index 00000000..befc91d1 --- /dev/null +++ b/src/main/scala/shark/memstore2/SharkTblProperties.scala @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.memstore2 + +import java.util.{Map => JavaMap} + + +/** + * Collection of static fields and helpers for table properties (i.e., from A + * CREATE TABLE TBLPROPERTIES( ... ) used by Shark. + */ +object SharkTblProperties { + + case class TableProperty(varname: String, defaultVal: String) + + // Class name of the default cache policy used to manage partition evictions for cached, + // Hive-partitioned tables. + val CACHE_POLICY = new TableProperty("shark.cache.policy", "shark.memstore2.CacheAllPolicy") + + // Maximum size - in terms of the number of objects - of the cache specified by the + // "shark.cache.partition.cachePolicy" property above. + val MAX_PARTITION_CACHE_SIZE = new TableProperty("shark.cache.policy.maxSize", "10") + + // Default value for the "shark.cache" table property + val CACHE_FLAG = new TableProperty("shark.cache", "true") + + def getOrSetDefault(tblProps: JavaMap[String, String], variable: TableProperty): String = { + if (!tblProps.containsKey(variable.varname)) { + tblProps.put(variable.varname, variable.defaultVal) + } + tblProps.get(variable.varname) + } + + /** + * Returns value for the `variable` table property. If a value isn't present in `tblProps`, then + * the default for `variable` will be returned. + */ + def initializeWithDefaults( + tblProps: JavaMap[String, String], + isPartitioned: Boolean = false): JavaMap[String, String] = { + tblProps.put(CACHE_FLAG.varname, CACHE_FLAG.defaultVal) + if (isPartitioned) { + tblProps.put(CACHE_POLICY.varname, CACHE_POLICY.defaultVal) + } + tblProps + } + + def removeSharkProperties(tblProps: JavaMap[String, String]) { + tblProps.remove(CACHE_FLAG.varname) + tblProps.remove(CACHE_POLICY.varname) + tblProps.remove(MAX_PARTITION_CACHE_SIZE.varname) + } +} diff --git a/src/main/scala/shark/memstore2/Table.scala b/src/main/scala/shark/memstore2/Table.scala new file mode 100644 index 00000000..ae7f451f --- /dev/null +++ b/src/main/scala/shark/memstore2/Table.scala @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
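// A small sketch of how the table-property helpers above are typically consulted; `tblProps`
// is an assumed java.util.Map[String, String] obtained from a Hive Table's parameters.
SharkTblProperties.initializeWithDefaults(tblProps, isPartitioned = true)
// tblProps now holds "shark.cache" -> "true" and
// "shark.cache.policy" -> "shark.memstore2.CacheAllPolicy"

val policyClassName = SharkTblProperties.getOrSetDefault(tblProps, SharkTblProperties.CACHE_POLICY)
val maxCachedPartitions =
  SharkTblProperties.getOrSetDefault(tblProps, SharkTblProperties.MAX_PARTITION_CACHE_SIZE).toInt

// When a table is uncached, the Shark-specific keys can be stripped again:
SharkTblProperties.removeSharkProperties(tblProps)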
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.memstore2 + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.rdd.RDD + +import scala.collection.mutable.Buffer + + +/** + * A container for table metadata managed by Shark and Spark. Subclasses are responsible for + * how RDDs are set, stored, and accessed. + * + * @param databaseName Namespace for this table. + * @param tableName Name of this table. + * @param cacheMode Type of memory storage used for the table (e.g., the Spark block manager). + */ +private[shark] abstract class Table( + var databaseName: String, + var tableName: String, + var cacheMode: CacheType.CacheType) { + + /** + * A mutable wrapper for an RDD and stats for its partitions. + */ + class RDDValue( + var rdd: RDD[TablePartition], + var stats: collection.Map[Int, TablePartitionStats]) { + + def toTuple = (rdd, stats) + } +} + +object Table { + + /** + * Merges contents of `otherStatsMaps` into `targetStatsMap`. + */ + def mergeStats( + targetStatsMap: Buffer[(Int, TablePartitionStats)], + otherStatsMap: Iterable[(Int, TablePartitionStats)] + ): Buffer[(Int, TablePartitionStats)] = { + val targetStatsMapSize = targetStatsMap.size + for ((otherIndex, tableStats) <- otherStatsMap) { + targetStatsMap.append((otherIndex + targetStatsMapSize, tableStats)) + } + targetStatsMap + } +} diff --git a/src/main/scala/shark/memstore2/TablePartition.scala b/src/main/scala/shark/memstore2/TablePartition.scala index 61235e85..ba8370a7 100644 --- a/src/main/scala/shark/memstore2/TablePartition.scala +++ b/src/main/scala/shark/memstore2/TablePartition.scala @@ -60,8 +60,6 @@ class TablePartition(private var _numRows: Long, private var _columns: Array[Byt buffer } - // TODO: Add column pruning to TablePartition for creating a TablePartitionIterator. - /** * Return an iterator for the partition. */ @@ -76,9 +74,9 @@ class TablePartition(private var _numRows: Long, private var _columns: Array[Byt def prunedIterator(columnsUsed: BitSet) = { val columnIterators: Array[ColumnIterator] = _columns.map { case buffer: ByteBuffer => - val iter = ColumnIterator.newIterator(buffer) - iter + ColumnIterator.newIterator(buffer) case _ => + // The buffer might be null if it is pruned in Tachyon. 
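// Worked numbers for Table.mergeStats() (statsA..statsD stand in for TablePartitionStats
// instances): entries from `other` are re-based past the entries already in `target`, so the
// partition indices of a unioned RDD stay unique.
import scala.collection.mutable.ArrayBuffer

val target = ArrayBuffer((0, statsA), (1, statsB))
Table.mergeStats(target, Seq((0, statsC), (1, statsD)))
// target is now ArrayBuffer((0, statsA), (1, statsB), (2, statsC), (3, statsD))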
null } new TablePartitionIterator(_numRows, columnIterators, columnsUsed) diff --git a/src/main/scala/shark/memstore2/TablePartitionBuilder.scala b/src/main/scala/shark/memstore2/TablePartitionBuilder.scala index cdd2843d..8614c070 100644 --- a/src/main/scala/shark/memstore2/TablePartitionBuilder.scala +++ b/src/main/scala/shark/memstore2/TablePartitionBuilder.scala @@ -18,10 +18,10 @@ package shark.memstore2 import java.io.{DataInput, DataOutput} -import java.util.{List => JList} + +import scala.collection.JavaConversions._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector -import org.apache.hadoop.hive.serde2.objectinspector.StructField import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.hadoop.io.Writable @@ -33,19 +33,22 @@ import shark.memstore2.column.ColumnBuilder * partition of data into columnar format and to generate a TablePartition. */ class TablePartitionBuilder( - oi: StructObjectInspector, + ois: Seq[ObjectInspector], initialColumnSize: Int, - shouldCompress: Boolean = true) + shouldCompress: Boolean) extends Writable { - var numRows: Long = 0 - val fields: JList[_ <: StructField] = oi.getAllStructFieldRefs + def this(oi: StructObjectInspector, initialColumnSize: Int, shouldCompress: Boolean = true) = { + this(oi.getAllStructFieldRefs.map(_.getFieldObjectInspector), initialColumnSize, shouldCompress) + } + + private var numRows: Long = 0 - val columnBuilders = Array.tabulate[ColumnBuilder[_]](fields.size) { i => - val columnBuilder = ColumnBuilder.create(fields.get(i).getFieldObjectInspector, shouldCompress) + private val columnBuilders: Array[ColumnBuilder[_]] = ois.map { oi => + val columnBuilder = ColumnBuilder.create(oi, shouldCompress) columnBuilder.initialize(initialColumnSize) columnBuilder - } + }.toArray def incrementRowCount() { numRows += 1 @@ -57,7 +60,7 @@ class TablePartitionBuilder( def stats: TablePartitionStats = new TablePartitionStats(columnBuilders.map(_.stats), numRows) - def build: TablePartition = new TablePartition(numRows, columnBuilders.map(_.build)) + def build(): TablePartition = new TablePartition(numRows, columnBuilders.map(_.build())) // We don't use these, but want to maintain Writable interface for SerDe override def write(out: DataOutput) {} diff --git a/src/main/scala/shark/memstore2/TablePartitionIterator.scala b/src/main/scala/shark/memstore2/TablePartitionIterator.scala index 71aabd7c..947cdd22 100644 --- a/src/main/scala/shark/memstore2/TablePartitionIterator.scala +++ b/src/main/scala/shark/memstore2/TablePartitionIterator.scala @@ -17,7 +17,6 @@ package shark.memstore2 -import java.nio.ByteBuffer import java.util.BitSet import shark.memstore2.column.ColumnIterator @@ -45,13 +44,13 @@ class TablePartitionIterator( private var _position: Long = 0 - def hasNext(): Boolean = _position < numRows + def hasNext: Boolean = _position < numRows def next(): ColumnarStruct = { _position += 1 var i = columnUsed.nextSetBit(0) while (i > -1) { - columnIterators(i).next + columnIterators(i).next() i = columnUsed.nextSetBit(i + 1) } _struct diff --git a/src/main/scala/shark/memstore2/TableRecovery.scala b/src/main/scala/shark/memstore2/TableRecovery.scala new file mode 100644 index 00000000..adf61061 --- /dev/null +++ b/src/main/scala/shark/memstore2/TableRecovery.scala @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. 
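// A hedged sketch of column-pruned scanning with the iterator above (`partition` is an
// assumed TablePartition). Only the columns whose bits are set are materialized per row.
import java.util.BitSet

val columnsUsed = new BitSet()
columnsUsed.set(0)
columnsUsed.set(2)
val iter = partition.prunedIterator(columnsUsed)
while (iter.hasNext) {
  val row = iter.next() // the same ColumnarStruct instance is reused across calls
  // read the needed fields from `row` before advancing
}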
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.memstore2 + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions.asScalaBuffer + +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.session.SessionState + +import shark.{LogHelper, SharkEnv} +import shark.util.QueryRewriteUtils + +/** + * Singleton used to reload RDDs upon server restarts. + */ +object TableRecovery extends LogHelper { + + val db = Hive.get() + + /** + * Loads any cached tables with MEMORY as its `shark.cache` property. + * @param cmdRunner The runner that is responsible for taking a cached table query and + * a) Creating the table metadata in Hive Meta Store + * b) Loading the table as an RDD in memory + * @see SharkServer for an example usage. + * @param console Optional SessionState.LogHelper used, if present, to log information about + the tables that get reloaded. + */ + def reloadRdds(cmdRunner: String => Unit, console: Option[SessionState.LogHelper] = None) { + // Filter for tables that should be reloaded into the cache. + val currentDbName = db.getCurrentDatabase() + for (databaseName <- db.getAllDatabases(); tableName <- db.getAllTables(databaseName)) { + val hiveTable = db.getTable(databaseName, tableName) + val tblProps = hiveTable.getParameters + val cacheMode = CacheType.fromString(tblProps.get(SharkTblProperties.CACHE_FLAG.varname)) + if (cacheMode == CacheType.MEMORY) { + val logMessage = "Reloading %s.%s into memory.".format(databaseName, tableName) + if (console.isDefined) { + console.get.printInfo(logMessage) + } else { + logInfo(logMessage) + } + val cmd = QueryRewriteUtils.cacheToAlterTable("CACHE %s".format(tableName)) + cmdRunner(cmd) + } + } + db.setCurrentDatabase(currentDbName) + } +} diff --git a/src/main/scala/shark/memstore2/column/ColumnBuilder.scala b/src/main/scala/shark/memstore2/column/ColumnBuilder.scala index 375ec244..84988be3 100644 --- a/src/main/scala/shark/memstore2/column/ColumnBuilder.scala +++ b/src/main/scala/shark/memstore2/column/ColumnBuilder.scala @@ -61,12 +61,12 @@ trait ColumnBuilder[T] { _buffer.order(ByteOrder.nativeOrder()) _buffer.putInt(t.typeID) } - + protected def growIfNeeded(orig: ByteBuffer, size: Int): ByteBuffer = { val capacity = orig.capacity() if (orig.remaining() < size) { - //grow in steps of initial size - var additionalSize = capacity/8 + 1 + // grow in steps of initial size + val additionalSize = capacity / 8 + 1 var newSize = capacity + additionalSize if (additionalSize < size) { newSize = capacity + size @@ -82,7 +82,7 @@ trait ColumnBuilder[T] { } } -class DefaultColumnBuilder[T](val stats: ColumnStats[T], val t: ColumnType[T, _]) +class DefaultColumnBuilder[T](val stats: ColumnStats[T], val t: ColumnType[T, _]) extends CompressedColumnBuilder[T] with NullableColumnBuilder[T]{} @@ -105,7 +105,7 @@ trait CompressedColumnBuilder[T] extends ColumnBuilder[T] { override def build() = { val b = super.build() - + if (compressionSchemes.isEmpty) { new 
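// A minimal sketch of invoking the recovery path above. The runner body here is an
// assumption; SharkServer supplies a real command processor that executes the statement
// rewritten by QueryRewriteUtils.cacheToAlterTable().
TableRecovery.reloadRdds(cmd => println("Replaying cache command: " + cmd))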
NoCompression().compress(b, t) } else { @@ -136,16 +136,16 @@ object ColumnBuilder { case PrimitiveCategory.BYTE => new ByteColumnBuilder case PrimitiveCategory.TIMESTAMP => new TimestampColumnBuilder case PrimitiveCategory.BINARY => new BinaryColumnBuilder - + // TODO: add decimal column. - case _ => throw new Exception( + case _ => throw new MemoryStoreException( "Invalid primitive object inspector category" + columnOi.getCategory) } } case _ => new GenericColumnBuilder(columnOi) } if (shouldCompress) { - v.compressionSchemes = Seq(new RLE()) + v.compressionSchemes = Seq(new RLE(), new BooleanBitSetCompression()) } v } diff --git a/src/main/scala/shark/memstore2/column/ColumnBuilders.scala b/src/main/scala/shark/memstore2/column/ColumnBuilders.scala index 593f8685..6cee1359 100644 --- a/src/main/scala/shark/memstore2/column/ColumnBuilders.scala +++ b/src/main/scala/shark/memstore2/column/ColumnBuilders.scala @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column import java.nio.ByteBuffer @@ -12,18 +29,6 @@ import shark.execution.serialization.KryoSerializer import shark.memstore2.column.ColumnStats._ -class GenericColumnBuilder(oi: ObjectInspector) - extends DefaultColumnBuilder[ByteStream.Output](new NoOpStats(), GENERIC) { - - override def initialize(initialSize: Int):ByteBuffer = { - val buffer = super.initialize(initialSize) - val objectInspectorSerialized = KryoSerializer.serialize(oi) - buffer.putInt(objectInspectorSerialized.size) - buffer.put(objectInspectorSerialized) - buffer - } -} - class BooleanColumnBuilder extends DefaultColumnBuilder[Boolean](new BooleanColumnStats(), BOOLEAN) class IntColumnBuilder extends DefaultColumnBuilder[Int](new IntColumnStats(), INT) @@ -45,4 +50,20 @@ class TimestampColumnBuilder class BinaryColumnBuilder extends DefaultColumnBuilder[BytesWritable](new NoOpStats(), BINARY) -class VoidColumnBuilder extends DefaultColumnBuilder[Void](new NoOpStats(), VOID) \ No newline at end of file +class VoidColumnBuilder extends DefaultColumnBuilder[Void](new NoOpStats(), VOID) + +/** + * Generic columns that we can serialize, including maps, structs, and other complex types. + */ +class GenericColumnBuilder(oi: ObjectInspector) + extends DefaultColumnBuilder[ByteStream.Output](new NoOpStats(), GENERIC) { + + // Complex data types cannot be null. Override the initialize in NullableColumnBuilder. 
+ override def initialize(initialSize: Int): ByteBuffer = { + val buffer = super.initialize(initialSize) + val objectInspectorSerialized = KryoSerializer.serialize(oi) + buffer.putInt(objectInspectorSerialized.size) + buffer.put(objectInspectorSerialized) + buffer + } +} diff --git a/src/main/scala/shark/memstore2/column/ColumnIterator.scala b/src/main/scala/shark/memstore2/column/ColumnIterator.scala index 5c9b267c..404e456b 100644 --- a/src/main/scala/shark/memstore2/column/ColumnIterator.scala +++ b/src/main/scala/shark/memstore2/column/ColumnIterator.scala @@ -17,28 +17,30 @@ package shark.memstore2.column -import java.nio.ByteBuffer -import java.nio.ByteOrder +import scala.language.implicitConversions +import java.nio.{ByteBuffer, ByteOrder} trait ColumnIterator { - private var _initialized = false - + init() + def init() {} - def next() { - if (!_initialized) { - init() - _initialized = true - } - computeNext() - } + /** + * Produces the next element of this iterator. + */ + def next() - def computeNext(): Unit + /** + * Tests whether this iterator can provide another element. + */ + def hasNext: Boolean - // Should be implemented as a read-only operation by the ColumnIterator - // Can be called any number of times + /** + * Return the current element. The operation should have no side-effect, i.e. it can be invoked + * multiple times returning the same value. + */ def current: Object } @@ -49,25 +51,27 @@ abstract class DefaultColumnIterator[T, V](val buffer: ByteBuffer, val columnTyp object Implicits { implicit def intToCompressionType(i: Int): CompressionType = i match { - case -1 => DefaultCompressionType - case 0 => RLECompressionType - case 1 => DictionaryCompressionType - case _ => throw new UnsupportedOperationException("Compression Type " + i) + case DefaultCompressionType.typeID => DefaultCompressionType + case RLECompressionType.typeID => RLECompressionType + case DictionaryCompressionType.typeID => DictionaryCompressionType + case BooleanBitSetCompressionType.typeID => BooleanBitSetCompressionType + case _ => throw new MemoryStoreException("Unknown compression type " + i) } implicit def intToColumnType(i: Int): ColumnType[_, _] = i match { - case 0 => INT - case 1 => LONG - case 2 => FLOAT - case 3 => DOUBLE - case 4 => BOOLEAN - case 5 => BYTE - case 6 => SHORT - case 7 => VOID - case 8 => STRING - case 9 => TIMESTAMP - case 10 => BINARY - case 11 => GENERIC + case INT.typeID => INT + case LONG.typeID => LONG + case FLOAT.typeID => FLOAT + case DOUBLE.typeID => DOUBLE + case BOOLEAN.typeID => BOOLEAN + case BYTE.typeID => BYTE + case SHORT.typeID => SHORT + case VOID.typeID => VOID + case STRING.typeID => STRING + case TIMESTAMP.typeID => TIMESTAMP + case BINARY.typeID => BINARY + case GENERIC.typeID => GENERIC + case _ => throw new MemoryStoreException("Unknown column type " + i) } } @@ -76,9 +80,14 @@ object ColumnIterator { import shark.memstore2.column.Implicits._ def newIterator(b: ByteBuffer): ColumnIterator = { + new NullableColumnIterator(b.duplicate().order(ByteOrder.nativeOrder())) + } + + def newNonNullIterator(b: ByteBuffer): ColumnIterator = { + // The first 4 bytes in the buffer indicates the column type. 
val buffer = b.duplicate().order(ByteOrder.nativeOrder()) val columnType: ColumnType[_, _] = buffer.getInt() - val v = columnType match { + columnType match { case INT => new IntColumnIterator(buffer) case LONG => new LongColumnIterator(buffer) case FLOAT => new FloatColumnIterator(buffer) @@ -92,6 +101,5 @@ object ColumnIterator { case TIMESTAMP => new TimestampColumnIterator(buffer) case GENERIC => new GenericColumnIterator(buffer) } - new NullableColumnIterator(v, buffer) } } diff --git a/src/main/scala/shark/memstore2/column/ColumnIterators.scala b/src/main/scala/shark/memstore2/column/ColumnIterators.scala index be9902b5..3060b5b1 100644 --- a/src/main/scala/shark/memstore2/column/ColumnIterators.scala +++ b/src/main/scala/shark/memstore2/column/ColumnIterators.scala @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column import java.nio.ByteBuffer @@ -32,9 +49,9 @@ class BinaryColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(buf class StringColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(buffer, STRING) class GenericColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(buffer, GENERIC) { - + private var _obj: LazyObject[_] = _ - + override def init() { super.init() val oiSize = buffer.getInt() @@ -43,7 +60,7 @@ class GenericColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(bu val oi = KryoSerializer.deserialize[ObjectInspector](oiSerialized) _obj = LazyFactory.createLazyObject(oi) } - + override def current = { val v = super.current.asInstanceOf[ByteArrayRef] _obj.init(v, 0, v.getData().length) diff --git a/src/main/scala/shark/memstore2/column/ColumnStats.scala b/src/main/scala/shark/memstore2/column/ColumnStats.scala index 31270fa3..dce811d5 100644 --- a/src/main/scala/shark/memstore2/column/ColumnStats.scala +++ b/src/main/scala/shark/memstore2/column/ColumnStats.scala @@ -25,7 +25,8 @@ import org.apache.hadoop.io.Text /** - * Column level statistics, including range (min, max). + * Column level statistics, including range (min, max). We expect null values to be taken care + * of outside of the ColumnStats, so none of these stats should take null values. 
*/ sealed trait ColumnStats[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) T] extends Serializable { @@ -35,7 +36,6 @@ sealed trait ColumnStats[@specialized(Boolean, Byte, Short, Int, Long, Float, Do protected def _min: T protected def _max: T - def min: T = _min def max: T = _max @@ -67,27 +67,29 @@ object ColumnStats { class BooleanColumnStats extends ColumnStats[Boolean] { protected var _max = false protected var _min = true + override def append(v: Boolean) { if (v) _max = v else _min = v } + def :=(v: Any): Boolean = { v match { - case u:Boolean => _min <= u && _max >= u + case u: Boolean => _min <= u && _max >= u case _ => true } } def :>(v: Any): Boolean = { v match { - case u:Boolean => _max > u + case u: Boolean => _max > u case _ => true } } def :<(v: Any): Boolean = { v match { - case u:Boolean => _min < u + case u: Boolean => _min < u case _ => true } } @@ -97,6 +99,7 @@ object ColumnStats { class ByteColumnStats extends ColumnStats[Byte] { protected var _max = Byte.MinValue protected var _min = Byte.MaxValue + override def append(v: Byte) { if (v > _max) _max = v if (v < _min) _min = v @@ -104,21 +107,21 @@ object ColumnStats { def :=(v: Any): Boolean = { v match { - case u:Byte => _min <= u && _max >= u + case u: Byte => _min <= u && _max >= u case _ => true } } def :>(v: Any): Boolean = { v match { - case u:Byte => _max > u + case u: Byte => _max > u case _ => true } } def :<(v: Any): Boolean = { v match { - case u:Byte => _min < u + case u: Byte => _min < u case _ => true } } @@ -127,27 +130,29 @@ object ColumnStats { class ShortColumnStats extends ColumnStats[Short] { protected var _max = Short.MinValue protected var _min = Short.MaxValue + override def append(v: Short) { if (v > _max) _max = v if (v < _min) _min = v } + def :=(v: Any): Boolean = { v match { - case u:Short => _min <= u && _max >= u + case u: Short => _min <= u && _max >= u case _ => true } } def :>(v: Any): Boolean = { v match { - case u:Short => _max > u + case u: Short => _max > u case _ => true } } def :<(v: Any): Boolean = { v match { - case u:Short => _min < u + case u: Short => _min < u case _ => true } } @@ -184,14 +189,14 @@ object ColumnStats { def :>(v: Any): Boolean = { v match { - case u:Int => _max > u + case u: Int => _max > u case _ => true } } def :<(v: Any): Boolean = { v match { - case u:Int => _min < u + case u: Int => _min < u case _ => true } } @@ -228,27 +233,29 @@ object ColumnStats { class LongColumnStats extends ColumnStats[Long] { protected var _max = Long.MinValue protected var _min = Long.MaxValue + override def append(v: Long) { if (v > _max) _max = v if (v < _min) _min = v } + def :=(v: Any): Boolean = { v match { - case u:Long => _min <= u && _max >= u + case u: Long => _min <= u && _max >= u case _ => true } } def :>(v: Any): Boolean = { v match { - case u:Long => _max > u + case u: Long => _max > u case _ => true } } def :<(v: Any): Boolean = { v match { - case u:Long => _min < u + case u: Long => _min < u case _ => true } } @@ -257,20 +264,22 @@ object ColumnStats { class FloatColumnStats extends ColumnStats[Float] { protected var _max = Float.MinValue protected var _min = Float.MaxValue + override def append(v: Float) { if (v > _max) _max = v if (v < _min) _min = v } + def :=(v: Any): Boolean = { v match { - case u:Float => _min <= u && _max >= u + case u: Float => _min <= u && _max >= u case _ => true } } def :>(v: Any): Boolean = { v match { - case u:Float => _max > u + case u: Float => _max > u case _ => true } } @@ -286,10 +295,12 @@ object ColumnStats { 
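// A sketch of how the range operators above are meant to be read (IntColumnStats is assumed
// to follow the same := / :> / :< pattern as the numeric classes shown here): they answer
// "could a matching value exist in this partition?", so `true` keeps the partition and
// `false` allows it to be skipped.
val stats = new ColumnStats.IntColumnStats
Seq(3, 8, 5).foreach(stats.append)
stats := 6     // true:  6 lies within [min, max] = [3, 8]
stats :> 8     // false: nothing in the column can exceed the max of 8
stats :< 3     // false: nothing in the column can be below the min of 3
stats := "x"   // true:  unexpected types fall through conservatively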
class DoubleColumnStats extends ColumnStats[Double] { protected var _max = Double.MinValue protected var _min = Double.MaxValue + override def append(v: Double) { if (v > _max) _max = v if (v < _min) _min = v } + def :=(v: Any): Boolean = { v match { case u:Double => _min <= u && _max >= u @@ -315,10 +326,12 @@ object ColumnStats { class TimestampColumnStats extends ColumnStats[Timestamp] { protected var _max = new Timestamp(0) protected var _min = new Timestamp(Long.MaxValue) + override def append(v: Timestamp) { if (v.compareTo(_max) > 0) _max = v if (v.compareTo(_min) < 0) _min = v } + def :=(v: Any): Boolean = { v match { case u: Timestamp => _min.compareTo(u) <=0 && _max.compareTo(u) >= 0 @@ -345,8 +358,12 @@ object ColumnStats { // Note: this is not Java serializable because Text is not Java serializable. protected var _max: Text = null protected var _min: Text = null - + def :=(v: Any): Boolean = { + if (_max eq null) { + // This partition doesn't contain any non-null strings in this column. Return false. + return false + } v match { case u: Text => _min.compareTo(u) <= 0 && _max.compareTo(u) >= 0 case u: String => this := new Text(u) @@ -355,6 +372,10 @@ object ColumnStats { } def :>(v: Any): Boolean = { + if (_max eq null) { + // This partition doesn't contain any non-null strings in this column. Return false. + return false + } v match { case u: Text => _max.compareTo(u) > 0 case u: String => this :> new Text(u) @@ -363,14 +384,19 @@ object ColumnStats { } def :<(v: Any): Boolean = { + if (_max eq null) { + // This partition doesn't contain any non-null strings in this column. Return false. + return false + } v match { - case u:Text => _min.compareTo(u) < 0 + case u: Text => _min.compareTo(u) < 0 case u: String => this :< new Text(u) case _ => true } } override def append(v: Text) { + assert(!(v eq null)) // Need to make a copy of Text since Text is not immutable and we reuse // the same Text object in serializer to mitigate frequent GC. if (_max == null) { @@ -382,7 +408,7 @@ object ColumnStats { _min = new Text(v) } else if (v.compareTo(_min) < 0) { _min.set(v.getBytes(), 0, v.getLength()) - } + } } override def readExternal(in: ObjectInput) { diff --git a/src/main/scala/shark/memstore2/column/ColumnType.scala b/src/main/scala/shark/memstore2/column/ColumnType.scala index 068efe42..4ca62a19 100644 --- a/src/main/scala/shark/memstore2/column/ColumnType.scala +++ b/src/main/scala/shark/memstore2/column/ColumnType.scala @@ -20,39 +20,85 @@ package shark.memstore2.column import java.nio.ByteBuffer import java.sql.Timestamp +import scala.reflect.ClassTag + import org.apache.hadoop.hive.serde2.ByteStream import org.apache.hadoop.hive.serde2.`lazy`.{ByteArrayRef, LazyBinary} -import org.apache.hadoop.hive.serde2.io.{TimestampWritable, ShortWritable, ByteWritable, DoubleWritable} +import org.apache.hadoop.hive.serde2.io.ByteWritable +import org.apache.hadoop.hive.serde2.io.DoubleWritable +import org.apache.hadoop.hive.serde2.io.ShortWritable +import org.apache.hadoop.hive.serde2.io.TimestampWritable import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io._ -abstract class ColumnType[T, V](val typeID: Int, val defaultSize: Int) { - - def extract(currentPos: Int, buffer: ByteBuffer): T - +/** + * @param typeID A unique ID representing the type. + * @param defaultSize Default size in bytes for one element of type T (e.g. Int = 4). + * @tparam T Scala data type for the column. 
+ * @tparam V Writable data type for the column. + */ +sealed abstract class ColumnType[T : ClassTag, V : ClassTag]( + val typeID: Int, val defaultSize: Int) { + + /** + * Scala ClassTag. Can be used to create primitive arrays and hash tables. + */ + def scalaTag = implicitly[ClassTag[T]] + + /** + * Scala ClassTag. Can be used to create primitive arrays and hash tables. + */ + def writableScalaTag = implicitly[ClassTag[V]] + + /** + * Extract a value out of the buffer at the buffer's current position. + */ + def extract(buffer: ByteBuffer): T + + /** + * Append the given value v of type T into the given ByteBuffer. + */ def append(v: T, buffer: ByteBuffer) + /** + * Return the Scala data representation of the given object, using an object inspector. + */ def get(o: Object, oi: ObjectInspector): T - def actualSize(v: T) = defaultSize - - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: V) - + /** + * Return the size of the value. This is used to calculate the size of variable length types + * such as byte arrays and strings. + */ + def actualSize(v: T): Int = defaultSize + + /** + * Extract a value out of the buffer at the buffer's current position, and put it in the writable + * object. This is used as an optimization to reduce the temporary objects created, since the + * writable object can be reused. + */ + def extractInto(buffer: ByteBuffer, writable: V) + + /** + * Create a new writable object corresponding to this type. + */ def newWritable(): V + /** + * Create a duplicated copy of the value. + */ def clone(v: T): T = v } object INT extends ColumnType[Int, IntWritable](0, 4) { - override def append(v: Int, buffer: ByteBuffer) = { + override def append(v: Int, buffer: ByteBuffer) { buffer.putInt(v) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { buffer.getInt() } @@ -60,8 +106,8 @@ object INT extends ColumnType[Int, IntWritable](0, 4) { oi.asInstanceOf[IntObjectInspector].get(o) } - override def extractInto(currentPos: Int, buffer: ByteBuffer, writable: IntWritable) = { - writable.set(extract(currentPos, buffer)) + override def extractInto(buffer: ByteBuffer, writable: IntWritable) { + writable.set(extract(buffer)) } override def newWritable() = new IntWritable @@ -70,11 +116,11 @@ object INT extends ColumnType[Int, IntWritable](0, 4) { object LONG extends ColumnType[Long, LongWritable](1, 8) { - override def append(v: Long, buffer: ByteBuffer) = { + override def append(v: Long, buffer: ByteBuffer) { buffer.putLong(v) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { buffer.getLong() } @@ -82,21 +128,21 @@ object LONG extends ColumnType[Long, LongWritable](1, 8) { oi.asInstanceOf[LongObjectInspector].get(o) } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: LongWritable) = { - writable.set(extract(currentPos, buffer)) + override def extractInto(buffer: ByteBuffer, writable: LongWritable) { + writable.set(extract(buffer)) } - def newWritable() = new LongWritable + override def newWritable() = new LongWritable } object FLOAT extends ColumnType[Float, FloatWritable](2, 4) { - override def append(v: Float, buffer: ByteBuffer) = { + override def append(v: Float, buffer: ByteBuffer) { buffer.putFloat(v) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { buffer.getFloat() } @@ -104,42 +150,43 @@ object FLOAT extends ColumnType[Float, FloatWritable](2, 4) { 
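// A round-trip sketch of the position-based extract/extractInto contract above (the old
// (currentPos, buffer) parameter is gone; reads advance the buffer's own position):
import java.nio.{ByteBuffer, ByteOrder}

val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
INT.append(42, buf)
INT.append(7, buf)
buf.rewind()

INT.extract(buf)                  // 42
val reused = INT.newWritable()    // an IntWritable that can be reused across rows
INT.extractInto(buf, reused)      // reused.get() == 7, no per-row allocation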
oi.asInstanceOf[FloatObjectInspector].get(o) } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: FloatWritable) = { - writable.set(extract(currentPos, buffer)) + override def extractInto(buffer: ByteBuffer, writable: FloatWritable) { + writable.set(extract(buffer)) } - def newWritable() = new FloatWritable + override def newWritable() = new FloatWritable } object DOUBLE extends ColumnType[Double, DoubleWritable](3, 8) { - override def append(v: Double, buffer: ByteBuffer) = { + override def append(v: Double, buffer: ByteBuffer) { buffer.putDouble(v) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { buffer.getDouble() } + override def get(o: Object, oi: ObjectInspector): Double = { oi.asInstanceOf[DoubleObjectInspector].get(o) } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: DoubleWritable) = { - writable.set(extract(currentPos, buffer)) + override def extractInto(buffer: ByteBuffer, writable: DoubleWritable) { + writable.set(extract(buffer)) } - def newWritable() = new DoubleWritable + override def newWritable() = new DoubleWritable } object BOOLEAN extends ColumnType[Boolean, BooleanWritable](4, 1) { - override def append(v: Boolean, buffer: ByteBuffer) = { + override def append(v: Boolean, buffer: ByteBuffer) { buffer.put(if (v) 1.toByte else 0.toByte) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { if (buffer.get() == 1) true else false } @@ -147,42 +194,42 @@ object BOOLEAN extends ColumnType[Boolean, BooleanWritable](4, 1) { oi.asInstanceOf[BooleanObjectInspector].get(o) } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: BooleanWritable) = { - writable.set(extract(currentPos, buffer)) + override def extractInto(buffer: ByteBuffer, writable: BooleanWritable) { + writable.set(extract(buffer)) } - def newWritable() = new BooleanWritable + override def newWritable() = new BooleanWritable } object BYTE extends ColumnType[Byte, ByteWritable](5, 1) { - override def append(v: Byte, buffer: ByteBuffer) = { + override def append(v: Byte, buffer: ByteBuffer) { buffer.put(v) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { buffer.get() } override def get(o: Object, oi: ObjectInspector): Byte = { oi.asInstanceOf[ByteObjectInspector].get(o) } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: ByteWritable) = { - writable.set(extract(currentPos, buffer)) + override def extractInto(buffer: ByteBuffer, writable: ByteWritable) { + writable.set(extract(buffer)) } - def newWritable() = new ByteWritable + override def newWritable() = new ByteWritable } object SHORT extends ColumnType[Short, ShortWritable](6, 2) { - override def append(v: Short, buffer: ByteBuffer) = { + override def append(v: Short, buffer: ByteBuffer) { buffer.putShort(v) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { buffer.getShort() } @@ -190,8 +237,8 @@ object SHORT extends ColumnType[Short, ShortWritable](6, 2) { oi.asInstanceOf[ShortObjectInspector].get(o) } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: ShortWritable) = { - writable.set(extract(currentPos, buffer)) + def extractInto(buffer: ByteBuffer, writable: ShortWritable) { + writable.set(extract(buffer)) } def newWritable() = new ShortWritable @@ -200,15 +247,15 @@ object SHORT extends ColumnType[Short, ShortWritable](6, 2) { object VOID 
extends ColumnType[Void, NullWritable](7, 0) { - override def append(v: Void, buffer: ByteBuffer) = {} + override def append(v: Void, buffer: ByteBuffer) {} - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { throw new UnsupportedOperationException() } override def get(o: Object, oi: ObjectInspector) = null - override def extractInto(currentPos: Int, buffer: ByteBuffer, writable: NullWritable) = {} + override def extractInto(buffer: ByteBuffer, writable: NullWritable) {} override def newWritable() = NullWritable.get } @@ -234,9 +281,9 @@ object STRING extends ColumnType[Text, Text](8, 8) { buffer.put(v.getBytes(), 0, length) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { val t = new Text() - extractInto(currentPos, buffer, t) + extractInto(buffer, t) t } @@ -246,7 +293,7 @@ object STRING extends ColumnType[Text, Text](8, 8) { override def actualSize(v: Text) = v.getLength() + 4 - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: Text) = { + override def extractInto(buffer: ByteBuffer, writable: Text) { val length = buffer.getInt() var b = _bytesFld.get(writable).asInstanceOf[Array[Byte]] if (b == null || b.length < length) { @@ -257,7 +304,8 @@ object STRING extends ColumnType[Text, Text](8, 8) { _lengthFld.set(writable, length) } - def newWritable() = new Text + override def newWritable() = new Text + override def clone(v: Text) = { val t = new Text() t.set(v) @@ -268,12 +316,12 @@ object STRING extends ColumnType[Text, Text](8, 8) { object TIMESTAMP extends ColumnType[Timestamp, TimestampWritable](9, 12) { - override def append(v: Timestamp, buffer: ByteBuffer) = { + override def append(v: Timestamp, buffer: ByteBuffer) { buffer.putLong(v.getTime()) buffer.putInt(v.getNanos()) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { val ts = new Timestamp(0) ts.setTime(buffer.getLong()) ts.setNanos(buffer.getInt()) @@ -284,11 +332,11 @@ object TIMESTAMP extends ColumnType[Timestamp, TimestampWritable](9, 12) { oi.asInstanceOf[TimestampObjectInspector].getPrimitiveJavaObject(o) } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: TimestampWritable) = { - writable.set(extract(currentPos, buffer)) + override def extractInto(buffer: ByteBuffer, writable: TimestampWritable) { + writable.set(extract(buffer)) } - def newWritable() = new TimestampWritable + override def newWritable() = new TimestampWritable } @@ -306,13 +354,13 @@ object BINARY extends ColumnType[BytesWritable, BytesWritable](10, 16) { f } - override def append(v: BytesWritable, buffer: ByteBuffer) = { + override def append(v: BytesWritable, buffer: ByteBuffer) { val length = v.getLength() buffer.putInt(length) buffer.put(v.getBytes(), 0, length) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { throw new UnsupportedOperationException() } @@ -324,7 +372,7 @@ object BINARY extends ColumnType[BytesWritable, BytesWritable](10, 16) { } } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: BytesWritable) = { + override def extractInto(buffer: ByteBuffer, writable: BytesWritable) { val length = buffer.getInt() var b = _bytesFld.get(writable).asInstanceOf[Array[Byte]] if (b == null || b.length < length) { @@ -335,7 +383,7 @@ object BINARY extends ColumnType[BytesWritable, BytesWritable](10, 16) { _lengthFld.set(writable, length) } - def newWritable() = new 
BytesWritable + override def newWritable() = new BytesWritable override def actualSize(v: BytesWritable) = v.getLength() + 4 } @@ -349,7 +397,7 @@ object GENERIC extends ColumnType[ByteStream.Output, ByteArrayRef](11, 16) { buffer.put(v.getData(), 0, length) } - override def extract(currentPos: Int, buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer) = { throw new UnsupportedOperationException() } @@ -357,12 +405,14 @@ object GENERIC extends ColumnType[ByteStream.Output, ByteArrayRef](11, 16) { o.asInstanceOf[ByteStream.Output] } - def extractInto(currentPos: Int, buffer: ByteBuffer, writable: ByteArrayRef) = { + override def extractInto(buffer: ByteBuffer, writable: ByteArrayRef) { val length = buffer.getInt() val a = new Array[Byte](length) buffer.get(a, 0, length) writable.setData(a) } - def newWritable() = new ByteArrayRef + override def newWritable() = new ByteArrayRef + + override def actualSize(v: ByteStream.Output): Int = v.getCount() + 4 } diff --git a/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala b/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala index 7b4e5ab8..5d74a61c 100644 --- a/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala +++ b/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala @@ -1,8 +1,25 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column import java.nio.ByteBuffer -import scala.collection.mutable.{Map, HashMap} +import org.apache.hadoop.io.BooleanWritable import shark.memstore2.column.Implicits._ @@ -11,9 +28,8 @@ import shark.memstore2.column.Implicits._ * The first element of the buffer at the point of initialization * is expected to be the type of compression indicator. */ -trait CompressedColumnIterator extends ColumnIterator{ +trait CompressedColumnIterator extends ColumnIterator { - private var _compressionType: CompressionType = _ private var _decoder: Iterator[_] = _ private var _current: Any = _ @@ -22,21 +38,25 @@ trait CompressedColumnIterator extends ColumnIterator{ def columnType: ColumnType[_,_] override def init() { - _compressionType = buffer.getInt() - _decoder = _compressionType match { + val compressionType: CompressionType = buffer.getInt() + _decoder = compressionType match { case DefaultCompressionType => new DefaultDecoder(buffer, columnType) case RLECompressionType => new RLDecoder(buffer, columnType) case DictionaryCompressionType => new DictDecoder(buffer, columnType) + case BooleanBitSetCompressionType => new BooleanBitSetDecoder(buffer, columnType) case _ => throw new UnsupportedOperationException() } } - override def computeNext() { + override def next() { + // TODO: can we remove the if branch? 
if (_decoder.hasNext) { _current = _decoder.next() } } - + + override def hasNext = _decoder.hasNext + override def current = _current.asInstanceOf[Object] } @@ -50,7 +70,7 @@ class DefaultDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extend override def hasNext = buffer.hasRemaining() override def next(): V = { - columnType.extractInto(buffer.position(), buffer, _current) + columnType.extractInto(buffer, _current) _current } } @@ -59,17 +79,17 @@ class DefaultDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extend * Run Length Decoder, decodes data compressed in RLE format of [element, length] */ class RLDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Iterator[V] { - + private var _run: Int = _ private var _count: Int = 0 private val _current: V = columnType.newWritable() override def hasNext = buffer.hasRemaining() - + override def next(): V = { if (_count == _run) { //next run - columnType.extractInto(buffer.position(), buffer, _current) + columnType.extractInto(buffer, _current) _run = buffer.getInt() _count = 1 } else { @@ -79,26 +99,62 @@ class RLDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Ite } } -class DictDecoder[V] (buffer:ByteBuffer, columnType: ColumnType[_, V]) extends Iterator[V] { +/** + * Dictionary encoding compression. + */ +class DictDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Iterator[V] { - private val _dictionary: Map[Int, V] = { + // Dictionary in the form of an array. The index is the encoded value, and the value is the + // decompressed value. + private val _dictionary: Array[V] = { val size = buffer.getInt() - val d = new HashMap[Int, V]() + val arr = columnType.writableScalaTag.newArray(size) var count = 0 while (count < size) { - //read text, followed by index - val text = columnType.extract(buffer.position(), buffer) - val index = buffer.getInt() - d.put(index, text.asInstanceOf[V]) - count+= 1 + val writable = columnType.newWritable() + columnType.extractInto(buffer, writable) + arr(count) = writable.asInstanceOf[V] + count += 1 } - d + arr } override def hasNext = buffer.hasRemaining() - + + override def next(): V = { + val index = buffer.getShort().toInt + _dictionary(index) + } +} + +/** + * Boolean BitSet encoding. + */ +class BooleanBitSetDecoder[V]( + buffer: ByteBuffer, + columnType: ColumnType[_, V], + var _pos: Int, + var _uncompressedSize: Int, + var _curValue: Long, + var _writable: BooleanWritable + ) extends Iterator[V] { + + def this(buffer: ByteBuffer, columnType: ColumnType[_, V]) + = this(buffer, columnType, 0, buffer.getInt(), 0, new BooleanWritable()) + + override def hasNext = _pos < _uncompressedSize + override def next(): V = { - val index = buffer.getInt() - _dictionary.get(index).get + val offset = _pos % BooleanBitSetCompression.BOOLEANS_PER_LONG + + if (offset == 0) { + _curValue = buffer.getLong() + } + + val retval: Boolean = (_curValue & (1 << offset)) != 0 + _pos += 1 + _writable.set(retval) + _writable.asInstanceOf[V] } } + diff --git a/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala b/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala index a26d2ff5..5db74dee 100644 --- a/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala +++ b/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala @@ -1,9 +1,26 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. 
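// A worked illustration of the run-length layout that RLDecoder walks for an INT column:
// each run is stored as [value][run length], both in native byte order.
//
//   logical values : 7, 7, 7, 9
//   encoded payload: [7][3] [9][1]
//
// next() keeps returning the current run's value until the run count is exhausted, then reads
// the next [value][length] pair, so iteration reproduces 7, 7, 7, 9.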
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column -import java.nio.ByteBuffer -import java.nio.ByteOrder +import java.nio.{ByteBuffer, ByteOrder} + import scala.annotation.tailrec -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} /** * API for Compression @@ -12,15 +29,35 @@ trait CompressionAlgorithm { def compressionType: CompressionType + /** + * Tests whether the compression algorithm supports a specific column type. + */ def supportsType(t: ColumnType[_, _]): Boolean + /** + * Collect a value so we can update the compression ratio for this compression algorithm. + */ def gatherStatsForCompressibility[T](v: T, t: ColumnType[T, _]) /** * Return compression ratio between 0 and 1, smaller score imply higher compressibility. + * This is used to pick the compression algorithm to apply at runtime. + */ + def compressionRatio: Double = compressedSize.toDouble / uncompressedSize.toDouble + + /** + * The uncompressed size of the input data. */ - def compressionRatio: Double + def uncompressedSize: Int + /** + * Estimation of the data size once compressed. + */ + def compressedSize: Int + + /** + * Compress the given buffer and return the compressed data as a new buffer. + */ def compress[T](b: ByteBuffer, t: ColumnType[T, _]): ByteBuffer } @@ -30,19 +67,28 @@ case class CompressionType(typeID: Int) object DefaultCompressionType extends CompressionType(-1) object RLECompressionType extends CompressionType(0) + object DictionaryCompressionType extends CompressionType(1) -object RLEVariantCompressionType extends CompressionType(2) +object BooleanBitSetCompressionType extends CompressionType(2) +/** + * An no-op compression. + */ class NoCompression extends CompressionAlgorithm { + override def compressionType = DefaultCompressionType override def supportsType(t: ColumnType[_,_]) = true - override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) = {} + override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) {} override def compressionRatio: Double = 1.0 + override def uncompressedSize: Int = 0 + + override def compressedSize: Int = 0 + override def compress[T](b: ByteBuffer, t: ColumnType[T, _]) = { val len = b.limit() val newBuffer = ByteBuffer.allocate(len + 4) @@ -57,52 +103,56 @@ class NoCompression extends CompressionAlgorithm { } /** - * Implements Run Length Encoding + * Run-length encoding for columns with a lot of repeated values. */ class RLE extends CompressionAlgorithm { - private var _total: Int = 0 + private var _uncompressedSize: Int = 0 + private var _compressedSize: Int = 0 + + // Previous element, used to track how many runs and the run lengths. private var _prev: Any = _ + // Current run length. 
private var _run: Int = 0 - private var _size: Int = 0 override def compressionType = RLECompressionType override def supportsType(t: ColumnType[_, _]) = { t match { - case INT | STRING | SHORT | BYTE | BOOLEAN => true + case LONG | INT | STRING | SHORT | BYTE | BOOLEAN => true case _ => false } } - override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) = { + override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) { val s = t.actualSize(v) if (_prev == null) { + // This is the very first run. _prev = t.clone(v) _run = 1 + _compressedSize += s + 4 } else { if (_prev.equals(v)) { + // Add one to the current run's length. _run += 1 } else { - // flush run into size - _size += (t.actualSize(_prev.asInstanceOf[T]) + 4) + // Start a new run. Update the current run length. + _compressedSize += s + 4 _prev = t.clone(v) _run = 1 } } - _total += s + _uncompressedSize += s } + override def uncompressedSize: Int = _uncompressedSize + // Note that we don't actually track the size of the last run into account to simplify the // logic a little bit. - override def compressionRatio = _size / (_total + 0.0) + override def compressedSize: Int = _compressedSize - override def compress[T](b: ByteBuffer, t: ColumnType[T,_]) = { - // Add the size of the last run to the _size - if (_prev != null) { - _size += t.actualSize(_prev.asInstanceOf[T]) + 4 - } - - val compressedBuffer = ByteBuffer.allocate(_size + 4 + 4) + override def compress[T](b: ByteBuffer, t: ColumnType[T,_]): ByteBuffer = { + // Leave 4 extra bytes for column type and another 4 for compression type. + val compressedBuffer = ByteBuffer.allocate(4 + 4 + _compressedSize) compressedBuffer.order(ByteOrder.nativeOrder()) compressedBuffer.putInt(b.getInt()) compressedBuffer.putInt(compressionType.typeID) @@ -112,7 +162,7 @@ class RLE extends CompressionAlgorithm { } @tailrec private final def encode[T](currentBuffer: ByteBuffer, - compressedBuffer: ByteBuffer, currentRun: (T, Int), t: ColumnType[T,_]) { + compressedBuffer: ByteBuffer, currentRun: (T, Int), t: ColumnType[T,_]) { def writeOutRun() { t.append(currentRun._1, compressedBuffer) compressedBuffer.putInt(currentRun._2) @@ -121,7 +171,7 @@ class RLE extends CompressionAlgorithm { writeOutRun() return } - val elem = t.extract(currentBuffer.position(), currentBuffer) + val elem = t.extract(currentBuffer) val newRun = if (currentRun == null) { (elem, 1) @@ -137,88 +187,175 @@ class RLE extends CompressionAlgorithm { } } +/** + * Dictionary encoding for columns with small cardinality. This algorithm encodes values into + * short integers (2 byte each). It can support up to 32k distinct values. + */ class DictionaryEncoding extends CompressionAlgorithm { - private val MAX_DICT_SIZE = 4000 - private val _dictionary = new HashMap[Any, Int]() - private var _dictionarySize = 0 - private var _totalSize = 0 + // 32K unique values allowed + private val MAX_DICT_SIZE = Short.MaxValue - 1 + + // The dictionary that maps a value to the encoded short integer. + private var _dictionary = new HashMap[Any, Short]() + + // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself. + private var _values = new ArrayBuffer[Any](1024) + + // We use a short integer to store the dictionary index, which takes 2 bytes. + private val indexSize = 2 + + // Size of the dictionary, in bytes. Initialize the dictionary size to 4 since we use an int + // to store the number of elements in the dictionary. 
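+  // (compress() writes this element count at the start of the dictionary section, ahead of the
+  //  dictionary values and the 2-byte encoded indices.)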
+ private var _dictionarySize = 4 + + // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary + // overflows. + private var _uncompressedSize = 0 + + // Total number of elements. private var _count = 0 - private var _index = 0 + + // If the number of distinct elements is too large, we discard the use of dictionary + // encoding and set the overflow flag to true. private var _overflow = false override def compressionType = DictionaryCompressionType override def supportsType(t: ColumnType[_, _]) = t match { - case STRING => true + case STRING | LONG | INT => true case _ => false } - private def encode[T](v: T, t: ColumnType[T, _], sizeFunc:T => Int): Int = { - _count += 1 - val size = sizeFunc(v) - _totalSize += size - if (_dictionary.size < MAX_DICT_SIZE) { - val s = t.clone(v) - _dictionary.get(s) match { - case Some(index) => index - case None => { - _dictionary.put(s, _index) - _index += 1 - _dictionarySize += (size + 4) - _index + override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T, _]) { + // Use this function to build up a dictionary. + if (!_overflow) { + val size = t.actualSize(v) + _count += 1 + _uncompressedSize += size + + if (!_dictionary.contains(v)) { + // The dictionary doesn't contain the value. Add the value to the dictionary if we haven't + // overflown yet. + if (_dictionary.size < MAX_DICT_SIZE) { + val clone = t.clone(v) + _values.append(clone) + _dictionary.put(clone, _dictionary.size.toShort) + _dictionarySize += size + } else { + // Overflown. Release the dictionary immediately to lower memory pressure. + _overflow = true + _dictionary = null + _values = null } } - } else { - _overflow = true - -1 } } - override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T, _]) = { - //need an estimate of the # of uniques so we can build an appropriate - //dictionary if needed. More precisely, we only need a lower bound - //on # of uniques. - val size = t.actualSize(v) - encode(v, t, { _:T => size}) - } + override def uncompressedSize: Int = _uncompressedSize /** - * return score between 0 and 1, smaller score imply higher compressibility. + * Return the compressed data size if encoded with dictionary encoding. If the dictionary + * cardinality (i.e. the number of distinct elements) is bigger than 32K, we return an + * a really large number. */ - override def compressionRatio: Double = { - if (_overflow) 1.0 else (_count*4 + dictionarySize) / (_totalSize + 0.0) + override def compressedSize: Int = { + // Total compressed size = + // size of the dictionary + + // the number of elements * dictionary encoded size (short) + if (_overflow) Int.MaxValue else _dictionarySize + _count * indexSize } - private def writeDictionary[T](compressedBuffer: ByteBuffer, t: ColumnType[T, _]) { - //store dictionary size + override def compress[T](b: ByteBuffer, t: ColumnType[T, _]): ByteBuffer = { + if (_overflow) { + throw new MemoryStoreException( + "Dictionary encoding should not be used because we have overflown the dictionary.") + } + + // Create a new buffer and store the compression type and column type. + // Leave 4 extra bytes for column type and another 4 for compression type. + val compressedBuffer = ByteBuffer.allocate(4 + 4 + compressedSize) + compressedBuffer.order(ByteOrder.nativeOrder()) + compressedBuffer.putInt(b.getInt()) + compressedBuffer.putInt(compressionType.typeID) + + // Write out the dictionary. 
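+    // The dictionary section is laid out as [number of entries][entry 0][entry 1]..., in index
+    // order, so the decoder can rebuild the index -> value mapping from position alone.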
compressedBuffer.putInt(_dictionary.size) - //store the dictionary - _dictionary.foreach { x => - t.append(x._1.asInstanceOf[T], compressedBuffer) - compressedBuffer.putInt(x._2) + _values.foreach { v => + t.append(v.asInstanceOf[T], compressedBuffer) } + + // Write out the encoded values, each is represented by a short integer. + while (b.hasRemaining()) { + val v = t.extract(b) + compressedBuffer.putShort(_dictionary(v)) + } + + // Rewind the compressed buffer and return it. + compressedBuffer.rewind() + compressedBuffer } +} - private def dictionarySize = _dictionarySize + 4 +/** +* BitSet compression for Boolean values. +*/ +object BooleanBitSetCompression { + val BOOLEANS_PER_LONG : Short = 64 +} - override def compress[T](b: ByteBuffer, t: ColumnType[T, _]): ByteBuffer = { - //build a dictionary of given size - val compressedBuffer = ByteBuffer.allocate(_count*4 + dictionarySize + 4 + 4) +class BooleanBitSetCompression extends CompressionAlgorithm { + + private var _uncompressedSize = 0 + + override def compressionType = BooleanBitSetCompressionType + + override def supportsType(t: ColumnType[_, _]) = { + t match { + case BOOLEAN => true + case _ => false + } + } + + override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) { + val s = t.actualSize(v) + _uncompressedSize += s + } + + // Booleans are encoded into Longs; in addition, we need one int to store the number of + // Booleans contained in the compressed buffer. + override def compressedSize: Int = { + math.ceil(_uncompressedSize.toFloat / BooleanBitSetCompression.BOOLEANS_PER_LONG).toInt * 8 + 4 + } + + override def uncompressedSize: Int = _uncompressedSize + + override def compress[T](b: ByteBuffer, t: ColumnType[T,_]): ByteBuffer = { + // Leave 4 extra bytes for column type, another 4 for compression type. + val compressedBuffer = ByteBuffer.allocate(4 + 4 + compressedSize) compressedBuffer.order(ByteOrder.nativeOrder()) compressedBuffer.putInt(b.getInt()) compressedBuffer.putInt(compressionType.typeID) - //store dictionary size - writeDictionary(compressedBuffer, t) - //traverse the original buffer - while (b.hasRemaining()) { - val v = t.extract(b.position(), b) - _dictionary.get(v).map { index => - compressedBuffer.putInt(index) + compressedBuffer.putInt(b.remaining()) + + var cur: Long = 0 + var pos: Int = 0 + var offset: Int = 0 + + while (b.hasRemaining) { + offset = pos % BooleanBitSetCompression.BOOLEANS_PER_LONG + val elem = t.extract(b).asInstanceOf[Boolean] + + if (elem) { + cur = (cur | (1 << offset)).toLong } - + if (offset == BooleanBitSetCompression.BOOLEANS_PER_LONG - 1 || !b.hasRemaining) { + compressedBuffer.putLong(cur) + cur = 0 + } + pos += 1 } compressedBuffer.rewind() compressedBuffer } -} \ No newline at end of file +} diff --git a/src/main/scala/shark/memstore2/column/MemoryStoreException.scala b/src/main/scala/shark/memstore2/column/MemoryStoreException.scala new file mode 100644 index 00000000..5db2631d --- /dev/null +++ b/src/main/scala/shark/memstore2/column/MemoryStoreException.scala @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.memstore2.column + + +class MemoryStoreException(message: String) extends Exception(message) diff --git a/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala b/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala index 2b544f4e..2d79fd87 100644 --- a/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala +++ b/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column import java.nio.ByteBuffer @@ -7,19 +24,20 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector /** - * Builds a nullable column. The byte buffer of a nullable column contains - * the column type, followed by the null count and the index of nulls, followed - * finally by the non nulls. + * Builds a nullable column. The byte buffer of a nullable column contains: + * - 4 bytes for the null count (number of nulls) + * - positions for each null, in ascending order + * - the non-null data (column data type, compression type, data...) 
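+ *
+ * For example, appending the values (1, NULL, 2, NULL) to an INT column yields, conceptually:
+ *   [null count = 2][null positions 1, 3][non-null section holding 1 and 2]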
*/ trait NullableColumnBuilder[T] extends ColumnBuilder[T] { private var _nulls: ByteBuffer = _ - + private var _pos: Int = _ - private var _nullCount:Int = _ + private var _nullCount: Int = _ override def initialize(initialSize: Int): ByteBuffer = { - _nulls = ByteBuffer.allocate(1024) + _nulls = ByteBuffer.allocate(1024) _nulls.order(ByteOrder.nativeOrder()) _pos = 0 _nullCount = 0 @@ -38,19 +56,16 @@ trait NullableColumnBuilder[T] extends ColumnBuilder[T] { } override def build(): ByteBuffer = { - val b = super.build() - if (_pos == 0) { - b - } else { - val v = _nulls.position() - _nulls.limit(v) - _nulls.rewind() - val newBuffer = ByteBuffer.allocate(b.limit + v + 4) - newBuffer.order(ByteOrder.nativeOrder()) - val colType= b.getInt() - newBuffer.putInt(colType).putInt(_nullCount).put(_nulls).put(b) - newBuffer.rewind() - newBuffer - } + val nonNulls = super.build() + val nullDataLen = _nulls.position() + _nulls.limit(nullDataLen) + _nulls.rewind() + + // 4 bytes for null count + null positions + non nulls + val newBuffer = ByteBuffer.allocate(4 + nullDataLen + nonNulls.limit) + newBuffer.order(ByteOrder.nativeOrder()) + newBuffer.putInt(_nullCount).put(_nulls).put(nonNulls) + newBuffer.rewind() + newBuffer } -} \ No newline at end of file +} diff --git a/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala b/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala index 66a3adfa..49e0eb20 100644 --- a/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala +++ b/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column import java.nio.ByteBuffer @@ -9,26 +26,30 @@ import java.nio.ByteOrder * Reading of non nulls is delegated by setting the buffer position to the first * non null. */ -class NullableColumnIterator(delegate: ColumnIterator, buffer: ByteBuffer) extends ColumnIterator { +class NullableColumnIterator(buffer: ByteBuffer) extends ColumnIterator { private var _d: ByteBuffer = _ private var _nullCount: Int = _ private var _nulls = 0 private var _isNull = false - private var _currentNullIndex:Int = _ + private var _currentNullIndex: Int = _ private var _pos = 0 + private var _delegate: ColumnIterator = _ + override def init() { _d = buffer.duplicate() _d.order(ByteOrder.nativeOrder()) _nullCount = _d.getInt() - buffer.position(buffer.position() + _nullCount * 4 + 4) _currentNullIndex = if (_nullCount > 0) _d.getInt() else Integer.MAX_VALUE _pos = 0 - delegate.init() + + // Move the buffer position to the non-null region. 
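+    // (4 bytes for the null count plus 4 bytes for each recorded null position.)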
+ buffer.position(buffer.position() + 4 + _nullCount * 4) + _delegate = ColumnIterator.newNonNullIterator(buffer) } - override def computeNext() { + override def next() { if (_pos == _currentNullIndex) { _nulls += 1 if (_nulls < _nullCount) { @@ -37,12 +58,12 @@ class NullableColumnIterator(delegate: ColumnIterator, buffer: ByteBuffer) exten _isNull = true } else { _isNull = false - delegate.computeNext() + _delegate.next() } _pos += 1 } - - def current: Object = { - if (_isNull) null else delegate.current - } + + override def hasNext: Boolean = (_nulls < _nullCount) || _delegate.hasNext + + def current: Object = if (_isNull) null else _delegate.current } diff --git a/src/main/scala/shark/parse/QueryBlock.scala b/src/main/scala/shark/parse/QueryBlock.scala new file mode 100644 index 00000000..4d79f12a --- /dev/null +++ b/src/main/scala/shark/parse/QueryBlock.scala @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.parse + +import org.apache.hadoop.hive.ql.parse.{QB => HiveQueryBlock} +import org.apache.hadoop.hive.ql.plan.CreateTableDesc +import org.apache.hadoop.hive.ql.plan.TableDesc + +import shark.memstore2.CacheType +import shark.memstore2.CacheType._ + + +/** + * A container for flags and table metadata. Used in SharkSemanticAnalyzer while parsing + * and analyzing ASTs (e.g. in SharkSemanticAnalyzer#analyzeCreateTable()). + */ +class QueryBlock(outerID: String, alias: String, isSubQuery: Boolean) + extends HiveQueryBlock(outerID, alias, isSubQuery) { + + // The CacheType for the table that will be created from CREATE TABLE/CTAS, or updated for an + // INSERT. + var cacheMode = CacheType.NONE + + // Descriptor for the table being updated by an INSERT. + var targetTableDesc: TableDesc = _ + + // Hive's QB uses `tableDesc` to refer to the CreateTableDesc. A direct `createTableDesc` + // makes it easier to differentiate from `_targetTableDesc`. + def createTableDesc: CreateTableDesc = super.getTableDesc + + def createTableDesc_= (desc: CreateTableDesc) = super.setTableDesc(desc) +} diff --git a/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala index a43a4975..3e5f69b2 100644 --- a/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala +++ b/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala @@ -1,24 +1,186 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.parse +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.parse.{ASTNode, BaseSemanticAnalyzer, DDLSemanticAnalyzer, HiveParser} +import org.apache.hadoop.hive.ql.exec.TaskFactory +import org.apache.hadoop.hive.ql.parse.ASTNode +import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer +import org.apache.hadoop.hive.ql.parse.DDLSemanticAnalyzer +import org.apache.hadoop.hive.ql.parse.HiveParser +import org.apache.hadoop.hive.ql.parse.SemanticException +import org.apache.hadoop.hive.ql.plan.{AlterTableDesc, DDLWork} import org.apache.spark.rdd.{UnionRDD, RDD} import shark.{LogHelper, SharkEnv} +import shark.execution.{SharkDDLWork, SparkLoadWork} +import shark.memstore2.{CacheType, MemoryMetadataManager, SharkTblProperties} class SharkDDLSemanticAnalyzer(conf: HiveConf) extends DDLSemanticAnalyzer(conf) with LogHelper { - override def analyzeInternal(node: ASTNode): Unit = { - super.analyzeInternal(node) - //handle drop table query - if (node.getToken().getType() == HiveParser.TOK_DROPTABLE) { - SharkEnv.unpersist(getTableName(node)) + override def analyzeInternal(ast: ASTNode): Unit = { + super.analyzeInternal(ast) + + ast.getToken.getType match { + case HiveParser.TOK_ALTERTABLE_ADDPARTS => { + analyzeAlterTableAddParts(ast) + } + case HiveParser.TOK_ALTERTABLE_DROPPARTS => { + analyzeDropTableOrDropParts(ast) + } + case HiveParser.TOK_ALTERTABLE_RENAME => { + analyzeAlterTableRename(ast) + } + case HiveParser.TOK_ALTERTABLE_PROPERTIES => { + analyzeAlterTableProperties(ast) + } + case HiveParser.TOK_DROPTABLE => { + analyzeDropTableOrDropParts(ast) + } + case _ => Unit + } + } + + /** + * Handle table property changes. + * How Shark-specific changes are handled: + * - "shark.cache": + * If the value evaluated by CacheType#shouldCache() is `true`, then create a SparkLoadTask to + * load the Hive table into memory. + * Set it as a dependent of the Hive DDLTask. A SharkDDLTask counterpart isn't created because + * the HadoopRDD creation and transformation isn't a direct Shark metastore operation + * (unlike the other cases handled in SharkDDLSemantiAnalyzer). * + * If 'false', then create a SharkDDLTask that will delete the table entry in the Shark + * metastore. + */ + def analyzeAlterTableProperties(ast: ASTNode) { + val databaseName = db.getCurrentDatabase() + val tableName = getTableName(ast) + val hiveTable = db.getTable(databaseName, tableName) + val newTblProps = getAlterTblDesc().getProps + val oldTblProps = hiveTable.getParameters + + val oldCacheMode = CacheType.fromString(oldTblProps.get(SharkTblProperties.CACHE_FLAG.varname)) + val newCacheMode = CacheType.fromString(newTblProps.get(SharkTblProperties.CACHE_FLAG.varname)) + if ((oldCacheMode == CacheType.TACHYON && newCacheMode != CacheType.TACHYON) || + (oldCacheMode == CacheType.MEMORY_ONLY && newCacheMode != CacheType.MEMORY_ONLY)) { + throw new SemanticException("""Table %s.%s's 'shark.cache' table property is %s. Only changes + from "'MEMORY' and 'NONE' are supported. Tables stored in TACHYON and MEMORY_ONLY must be + "dropped.""".format(databaseName, tableName, oldCacheMode)) + } else if (newCacheMode == CacheType.MEMORY) { + // The table should be cached (and is not already cached). 
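+      // For Hive-partitioned tables, collect a partition spec for every existing partition so
+      // that the SparkLoadTask created below loads all of them.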
+ val partSpecsOpt = if (hiveTable.isPartitioned) { + val columnNames = hiveTable.getPartCols.map(_.getName) + val partSpecs = db.getPartitions(hiveTable).map { partition => + val partSpec = new JavaHashMap[String, String]() + val values = partition.getValues() + columnNames.zipWithIndex.map { case(name, index) => partSpec.put(name, values(index)) } + partSpec + } + Some(partSpecs) + } else { + None + } + newTblProps.put(SharkTblProperties.CACHE_FLAG.varname, newCacheMode.toString) + val sparkLoadWork = new SparkLoadWork( + databaseName, + tableName, + SparkLoadWork.CommandTypes.NEW_ENTRY, + newCacheMode) + partSpecsOpt.foreach(partSpecs => sparkLoadWork.partSpecs = partSpecs) + rootTasks.head.addDependentTask(TaskFactory.get(sparkLoadWork, conf)) + } else if (newCacheMode == CacheType.NONE) { + // Uncache the table. + SharkEnv.memoryMetadataManager.dropTableFromMemory(db, databaseName, tableName) } } + def analyzeDropTableOrDropParts(ast: ASTNode) { + val databaseName = db.getCurrentDatabase() + val tableName = getTableName(ast) + val hiveTableOpt = Option(db.getTable(databaseName, tableName, false /* throwException */)) + // `hiveTableOpt` can be NONE for a DROP TABLE IF EXISTS command on a nonexistent table. + hiveTableOpt.foreach { hiveTable => + val cacheMode = CacheType.fromString( + hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname)) + // Create a SharkDDLTask only if the table is cached. + if (CacheType.shouldCache(cacheMode)) { + // Hive's DDLSemanticAnalyzer#analyzeInternal() will only populate rootTasks with DDLTasks + // and DDLWorks that contain DropTableDesc objects. + for (ddlTask <- rootTasks) { + val dropTableDesc = ddlTask.getWork.asInstanceOf[DDLWork].getDropTblDesc + val sharkDDLWork = new SharkDDLWork(dropTableDesc) + sharkDDLWork.cacheMode = cacheMode + ddlTask.addDependentTask(TaskFactory.get(sharkDDLWork, conf)) + } + } + } + } + + def analyzeAlterTableAddParts(ast: ASTNode) { + val databaseName = db.getCurrentDatabase() + val tableName = getTableName(ast) + val hiveTable = db.getTable(databaseName, tableName) + val cacheMode = CacheType.fromString( + hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname)) + // Create a SharkDDLTask only if the table is cached. + if (CacheType.shouldCache(cacheMode)) { + // Hive's DDLSemanticAnalyzer#analyzeInternal() will only populate rootTasks with DDLTasks + // and DDLWorks that contain AddPartitionDesc objects. + for (ddlTask <- rootTasks) { + val addPartitionDesc = ddlTask.getWork.asInstanceOf[DDLWork].getAddPartitionDesc + val sharkDDLWork = new SharkDDLWork(addPartitionDesc) + sharkDDLWork.cacheMode = cacheMode + ddlTask.addDependentTask(TaskFactory.get(sharkDDLWork, conf)) + } + } + } + + private def analyzeAlterTableRename(astNode: ASTNode) { + val databaseName = db.getCurrentDatabase() + val oldTableName = getTableName(astNode) + val hiveTable = db.getTable(databaseName, oldTableName) + val cacheMode = CacheType.fromString(hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname)) + if (CacheType.shouldCache(cacheMode)) { + val alterTableDesc = getAlterTblDesc() + val sharkDDLWork = new SharkDDLWork(alterTableDesc) + sharkDDLWork.cacheMode = cacheMode + rootTasks.head.addDependentTask(TaskFactory.get(sharkDDLWork, conf)) + } + } + + private def getAlterTblDesc(): AlterTableDesc = { + // Hive's DDLSemanticAnalyzer#analyzeInternal() will only populate rootTasks with a DDLTask + // and DDLWork that contains an AlterTableDesc. 
+ assert(rootTasks.size == 1) + val ddlTask = rootTasks.head + val ddlWork = ddlTask.getWork + assert(ddlWork.isInstanceOf[DDLWork]) + ddlWork.asInstanceOf[DDLWork].getAlterTblDesc + } + private def getTableName(node: ASTNode): String = { BaseSemanticAnalyzer.getUnescapedName(node.getChild(0).asInstanceOf[ASTNode]) } -} \ No newline at end of file +} diff --git a/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala index c8d69322..e139ac27 100755 --- a/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala +++ b/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala @@ -19,13 +19,11 @@ package shark.parse import java.io.Serializable import java.util.ArrayList -import java.util.List import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.parse._ -import org.apache.hadoop.hive.ql.plan.ExplainWork import shark.execution.SharkExplainWork diff --git a/src/main/scala/shark/parse/SharkLoadSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkLoadSemanticAnalyzer.scala new file mode 100644 index 00000000..fc32dbd7 --- /dev/null +++ b/src/main/scala/shark/parse/SharkLoadSemanticAnalyzer.scala @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.parse + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.exec.{CopyTask, MoveTask, TaskFactory} +import org.apache.hadoop.hive.ql.metadata.{Partition, Table => HiveTable} +import org.apache.hadoop.hive.ql.parse.{ASTNode, BaseSemanticAnalyzer, LoadSemanticAnalyzer} +import org.apache.hadoop.hive.ql.plan._ + +import shark.{LogHelper, SharkEnv} +import shark.execution.SparkLoadWork +import shark.memstore2.{CacheType, SharkTblProperties} + + +class SharkLoadSemanticAnalyzer(conf: HiveConf) extends LoadSemanticAnalyzer(conf) { + + override def analyzeInternal(ast: ASTNode): Unit = { + // Delegate to the LoadSemanticAnalyzer parent for error checking the source path formatting. + super.analyzeInternal(ast) + + // Children of the AST root created for a LOAD DATA [LOCAL] INPATH ... statement are, in order: + // 1. node containing the path specified by INPATH. + // 2. internal TOK_TABNAME node that contains the table's name. + // 3. (optional) node representing the LOCAL modifier. + val tableASTNode = ast.getChild(1).asInstanceOf[ASTNode] + val tableName = getTableName(tableASTNode) + val hiveTable = db.getTable(tableName) + val cacheMode = CacheType.fromString( + hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname)) + + if (CacheType.shouldCache(cacheMode)) { + // Find the arguments needed to instantiate a SparkLoadWork. 
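+      // The tableSpec yields the target partition (if any), while the MoveTask generated by Hive
+      // indicates whether this LOAD replaces existing data.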
+ val tableSpec = new BaseSemanticAnalyzer.tableSpec(db, conf, tableASTNode) + val hiveTable = tableSpec.tableHandle + val moveTask = getMoveTask() + val partSpecOpt = Option(tableSpec.getPartSpec) + val sparkLoadWork = SparkLoadWork( + db, + conf, + hiveTable, + partSpecOpt, + isOverwrite = moveTask.getWork.getLoadTableWork.getReplace) + + // Create a SparkLoadTask that will read from the table's data directory. Make it a dependent + // task of the LoadTask so that it's executed only if the LoadTask executes successfully. + moveTask.addDependentTask(TaskFactory.get(sparkLoadWork, conf)) + } + } + + private def getMoveTask(): MoveTask = { + assert(rootTasks.size == 1) + + // If the execution is local, then the root task is a CopyTask with a MoveTask child. + // Otherwise, the root is a MoveTask. + var rootTask = rootTasks.head + val moveTask = if (rootTask.isInstanceOf[CopyTask]) { + val firstChildTask = rootTask.getChildTasks.head + assert(firstChildTask.isInstanceOf[MoveTask]) + firstChildTask + } else { + rootTask + } + + // In Hive, LoadTableDesc is referred to as LoadTableWork ... + moveTask.asInstanceOf[MoveTask] + } + + private def getTableName(node: ASTNode): String = { + BaseSemanticAnalyzer.getUnescapedName(node.getChild(0).asInstanceOf[ASTNode]) + } +} diff --git a/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala index 199102bc..f3a7b49a 100755 --- a/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala +++ b/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala @@ -17,30 +17,33 @@ package shark.parse -import java.lang.reflect.Method import java.util.ArrayList import java.util.{List => JavaList} +import java.util.{Map => JavaMap} import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{FieldSchema, MetaException} import org.apache.hadoop.hive.metastore.Warehouse -import org.apache.hadoop.hive.ql.exec.{DDLTask, FetchTask, MoveTask, TaskFactory} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, MetaException} +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.ql.exec.{DDLTask, FetchTask} import org.apache.hadoop.hive.ql.exec.{FileSinkOperator => HiveFileSinkOperator} -import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.hive.ql.exec.MoveTask +import org.apache.hadoop.hive.ql.exec.{Operator => HiveOperator} +import org.apache.hadoop.hive.ql.exec.TaskFactory +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException} import org.apache.hadoop.hive.ql.optimizer.Optimizer import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.spark.storage.StorageLevel - -import shark.{CachedTableRecovery, LogHelper, SharkConfVars, SharkEnv, Utils} -import shark.execution.{HiveOperator, Operator, OperatorFactory, RDDUtils, ReduceSinkOperator, - SparkWork, TerminalOperator} -import shark.memstore2.{CacheType, ColumnarSerDe, MemoryMetadataManager} +import shark.{LogHelper, SharkConfVars, SharkEnv, Utils} +import shark.execution.{HiveDesc, Operator, OperatorFactory, RDDUtils, ReduceSinkOperator} +import shark.execution.{SharkDDLWork, SparkLoadWork, SparkWork, TerminalOperator} +import shark.memstore2.{CacheType, ColumnarSerDe, LazySimpleSerDeWrapper, MemoryMetadataManager} +import shark.memstore2.{MemoryTable, PartitionedMemoryTable, SharkTblProperties, 
TableRecovery} /** @@ -60,73 +63,67 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with override def getResultSchema() = _resSchema /** - * Override SemanticAnalyzer.analyzeInternal to handle CTAS caching. + * Override SemanticAnalyzer.analyzeInternal to handle CTAS caching and INSERT updates. + * + * Unified views: + * For CTAS and INSERT INTO/OVERWRITE the generated Shark query plan matches the one + * created if the target table were not cached. Disk => memory loading is done by a + * SparkLoadTask that executes _after_ all other tasks (SparkTask, Hive MoveTasks) finish + * executing. For INSERT INTO, the SparkLoadTask will be able to determine, using a path filter + * based on a snapshot of the table/partition data directory taken in genMapRedTasks(), new files + * that should be loaded into the cache. For CTAS, a path filter isn't used - everything in the + * data directory is loaded into the cache. + * + * Non-unified views (i.e., the cached table content is memory-only): + * The query plan's FileSinkOperator is replaced by a MemoryStoreSinkOperator. The + * MemoryStoreSinkOperator creates a new table (or partition) entry in the Shark metastore + * for CTAS, and creates UnionRDDs for INSERT INTO commands. */ override def analyzeInternal(ast: ASTNode): Unit = { reset() - val qb = new QB(null, null, false) + val qb = new QueryBlock(null, null, false) val pctx = getParseContext() pctx.setQB(qb) pctx.setParseTree(ast) init(pctx) + // The ASTNode that will be analyzed by SemanticAnalzyer#doPhase1(). var child: ASTNode = ast - logInfo("Starting Shark Semantic Analysis") + logDebug("Starting Shark Semantic Analysis") //TODO: can probably reuse Hive code for this - // analyze create table command - var cacheMode = CacheType.none - var isCTAS = false var shouldReset = false - if (ast.getToken().getType() == HiveParser.TOK_CREATETABLE) { + val astTokenType = ast.getToken().getType() + if (astTokenType == HiveParser.TOK_CREATEVIEW || astTokenType == HiveParser.TOK_ANALYZE) { + // Delegate create view and analyze to Hive. super.analyzeInternal(ast) - for (ch <- ast.getChildren) { - ch.asInstanceOf[ASTNode].getToken.getType match { - case HiveParser.TOK_QUERY => { - isCTAS = true - child = ch.asInstanceOf[ASTNode] - } - case _ => - Unit - } - } - - // If the table descriptor can be null if the CTAS has an - // "if not exists" condition. - val td = getParseContext.getQB.getTableDesc - if (!isCTAS || td == null) { - return - } else { - val checkTableName = SharkConfVars.getBoolVar(conf, SharkConfVars.CHECK_TABLENAME_FLAG) - val cacheType = CacheType.fromString(td.getTblProps().get("shark.cache")) - if (cacheType == CacheType.heap || - (td.getTableName.endsWith("_cached") && checkTableName)) { - cacheMode = CacheType.heap - td.getTblProps().put("shark.cache", cacheMode.toString) - } else if (cacheType == CacheType.tachyon || - (td.getTableName.endsWith("_tachyon") && checkTableName)) { - cacheMode = CacheType.tachyon - td.getTblProps().put("shark.cache", cacheMode.toString) + return + } else if (astTokenType == HiveParser.TOK_CREATETABLE) { + // Use Hive to do a first analysis pass. + super.analyzeInternal(ast) + // Do post-Hive analysis of the CREATE TABLE (e.g detect caching mode). + analyzeCreateTable(ast, qb) match { + case Some(queryStmtASTNode) => { + // Set the 'child' to reference the SELECT statement root node, with is a + // HiveParer.HIVE_QUERY. + child = queryStmtASTNode + // Hive's super.analyzeInternal() might generate MapReduce tasks. 
Avoid executing those + // tasks by reset()-ing some Hive SemanticAnalyzer state after doPhase1() is called below. + shouldReset = true } - - if (CacheType.shouldCache(cacheMode)) { - td.setSerName(classOf[ColumnarSerDe].getName) + case None => { + // Done with semantic analysis if the CREATE TABLE statement isn't a CTAS. + return } - - qb.setTableDesc(td) - shouldReset = true } } else { SessionState.get().setCommandType(HiveOperation.QUERY) } - // Delegate create view and analyze to Hive. - val astTokenType = ast.getToken().getType() - if (astTokenType == HiveParser.TOK_CREATEVIEW || astTokenType == HiveParser.TOK_ANALYZE) { - return super.analyzeInternal(ast) - } + // Invariant: At this point, the command will execute a query (i.e., its AST contains a + // HiveParser.TOK_QUERY node). // Continue analyzing from the child ASTNode. if (!doPhase1(child, qb, initPhase1Ctx())) { @@ -136,12 +133,14 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with // Used to protect against recursive views in getMetaData(). SharkSemanticAnalyzer.viewsExpandedField.set(this, new ArrayList[String]()) - logInfo("Completed phase 1 of Shark Semantic Analysis") + logDebug("Completed phase 1 of Shark Semantic Analysis") getMetaData(qb) - logInfo("Completed getting MetaData in Shark Semantic Analysis") + logDebug("Completed getting MetaData in Shark Semantic Analysis") // Reset makes sure we don't run the mapred jobs generated by Hive. - if (shouldReset) reset() + if (shouldReset) { + reset() + } // Save the result schema derived from the sink operator produced // by genPlan. This has the correct column names, which clients @@ -169,61 +168,89 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with // TODO: clean the following code. It's too messy to understand... val terminalOpSeq = { - if (qb.getParseInfo.isInsertToTable && !qb.isCTAS) { + val qbParseInfo = qb.getParseInfo + if (qbParseInfo.isInsertToTable && !qb.isCTAS) { + // Handle INSERT. There can be multiple Hive sink operators if the single command comprises + // multiple INSERTs. hiveSinkOps.map { hiveSinkOp => - val tableName = hiveSinkOp.asInstanceOf[HiveFileSinkOperator].getConf().getTableInfo() - .getTableName() - + val tableDesc = hiveSinkOp.asInstanceOf[HiveFileSinkOperator].getConf().getTableInfo() + val tableName = tableDesc.getTableName if (tableName == null || tableName == "") { // If table name is empty, it is an INSERT (OVERWRITE) DIRECTORY. OperatorFactory.createSharkFileOutputPlan(hiveSinkOp) } else { // Otherwise, check if we are inserting into a table that was cached. - val cachedTableName = tableName.split('.')(1) // Ignore the database name - SharkEnv.memoryMetadataManager.get(cachedTableName) match { - case Some(rdd) => { - if (hiveSinkOps.size == 1) { - // If useUnionRDD is false, the sink op is for INSERT OVERWRITE. - val useUnionRDD = qb.getParseInfo.isInsertIntoTable(cachedTableName) - val storageLevel = RDDUtils.getStorageLevelOfCachedTable(rdd) + val tableNameSplit = tableName.split('.') // Split from 'databaseName.tableName' + val cachedTableName = tableNameSplit(1) + val databaseName = tableNameSplit(0) + val hiveTable = Hive.get().getTable(databaseName, tableName) + val cacheMode = CacheType.fromString( + hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname)) + if (CacheType.shouldCache(cacheMode)) { + if (hiveSinkOps.size == 1) { + // INSERT INTO or OVERWRITE update on a cached table. 
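+            // Stash the target table descriptor in the QueryBlock; for MEMORY-mode tables,
+            // genMapRedTasks() uses it to build the disk-to-memory load task.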
+ qb.targetTableDesc = tableDesc + // If isInsertInto is true, the sink op is for INSERT INTO. + val isInsertInto = qbParseInfo.isInsertIntoTable(cachedTableName) + val isPartitioned = hiveTable.isPartitioned + var hivePartitionKeyOpt = if (isPartitioned) { + Some(SharkSemanticAnalyzer.getHivePartitionKey(qb)) + } else { + None + } + if (cacheMode == CacheType.MEMORY) { + // The table being updated is stored in memory and backed by disk, a + // SparkLoadTask will be created by the genMapRedTasks() call below. Set fields + // in `qb` that will be needed. + qb.cacheMode = cacheMode + qb.targetTableDesc = tableDesc + OperatorFactory.createSharkFileOutputPlan(hiveSinkOp) + } else { OperatorFactory.createSharkMemoryStoreOutputPlan( hiveSinkOp, cachedTableName, - storageLevel, - _resSchema.size, // numColumns - cacheMode == CacheType.tachyon, // use tachyon - useUnionRDD) - } else { - throw new SemanticException( - "Shark does not support updating cached table(s) with multiple INSERTs") + databaseName, + _resSchema.size, /* numColumns */ + hivePartitionKeyOpt, + cacheMode, + isInsertInto) } + } else { + throw new SemanticException( + "Shark does not support updating cached table(s) with multiple INSERTs") } - case None => OperatorFactory.createSharkFileOutputPlan(hiveSinkOp) + } else { + OperatorFactory.createSharkFileOutputPlan(hiveSinkOp) } } } } else if (hiveSinkOps.size == 1) { - // For a single output, we have the option of choosing the output - // destination (e.g. CTAS with table property "shark.cache" = "true"). Seq { - if (qb.isCTAS && qb.getTableDesc != null && CacheType.shouldCache(cacheMode)) { - val storageLevel = MemoryMetadataManager.getStorageLevelFromString( - qb.getTableDesc().getTblProps.get("shark.cache.storageLevel")) - qb.getTableDesc().getTblProps().put(CachedTableRecovery.QUERY_STRING, ctx.getCmd()) - OperatorFactory.createSharkMemoryStoreOutputPlan( - hiveSinkOps.head, - qb.getTableDesc.getTableName, - storageLevel, - _resSchema.size, // numColumns - cacheMode == CacheType.tachyon, // use tachyon - false) + // For a single output, we have the option of choosing the output + // destination (e.g. CTAS with table property "shark.cache" = "true"). + if (qb.isCTAS && qb.createTableDesc != null && CacheType.shouldCache(qb.cacheMode)) { + // The table being created from CTAS should be cached. + val tblProps = qb.createTableDesc.getTblProps + if (qb.cacheMode == CacheType.MEMORY) { + // Save the preferred storage level, since it's needed to create a SparkLoadTask in + // genMapRedTasks(). + OperatorFactory.createSharkFileOutputPlan(hiveSinkOps.head) + } else { + OperatorFactory.createSharkMemoryStoreOutputPlan( + hiveSinkOps.head, + qb.createTableDesc.getTableName, + qb.createTableDesc.getDatabaseName, + numColumns = _resSchema.size, + hivePartitionKeyOpt = None, + qb.cacheMode, + isInsertInto = false) + } } else if (pctx.getContext().asInstanceOf[QueryContext].useTableRddSink && !qb.isCTAS) { OperatorFactory.createSharkRddOutputPlan(hiveSinkOps.head) } else { OperatorFactory.createSharkFileOutputPlan(hiveSinkOps.head) } } - // A hack for the query plan dashboard to get the query plan. This was // done for SIGMOD demo. Turn it off by default. 
//shark.dashboard.QueryPlanDashboardHandler.terminalOperator = terminalOp @@ -237,15 +264,14 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with SharkSemanticAnalyzer.breakHivePlanByStages(terminalOpSeq) genMapRedTasks(qb, pctx, terminalOpSeq) - logInfo("Completed plan generation") + logDebug("Completed plan generation") } /** * Generate tasks for executing the query, including the SparkTask to do the * select, the MoveTask for updates, and the DDLTask for CTAS. */ - def genMapRedTasks(qb: QB, pctx: ParseContext, terminalOps: Seq[TerminalOperator]) { - + def genMapRedTasks(qb: QueryBlock, pctx: ParseContext, terminalOps: Seq[TerminalOperator]) { // Create the spark task. terminalOps.foreach { terminalOp => val task = TaskFactory.get(new SparkWork(pctx, terminalOp, _resSchema), conf) @@ -253,6 +279,7 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with } if (qb.getIsQuery) { + // Note: CTAS isn't considered a query - it's handled in the 'else' block below. // Configure FetchTask (used for fetching results to CLIDriver). val loadWork = getParseContext.getLoadFileWork.get(0) val cols = loadWork.getColumns @@ -268,9 +295,10 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with setFetchTask(fetchTask) } else { - // Configure MoveTasks for table updates (e.g. CTAS, INSERT). + // Configure MoveTasks for CTAS, INSERT. val mvTasks = new ArrayList[MoveTask]() + // For CTAS, `fileWork` contains a single LoadFileDesc (called "LoadFileWork" in Hive). val fileWork = getParseContext.getLoadFileWork val tableWork = getParseContext.getLoadTableWork tableWork.foreach { ltd => @@ -280,13 +308,14 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with fileWork.foreach { lfd => if (qb.isCTAS) { + // For CTAS, `lfd.targetDir` references the data directory of the table being created. var location = qb.getTableDesc.getLocation if (location == null) { try { - val dumpTable = db.newTable(qb.getTableDesc.getTableName) + val tableToCreate = db.newTable(qb.getTableDesc.getTableName) val wh = new Warehouse(conf) - location = wh.getTablePath(db.getDatabase(dumpTable.getDbName()), dumpTable - .getTableName()).toString; + location = wh.getTablePath(db.getDatabase(tableToCreate.getDbName()), tableToCreate + .getTableName()).toString; } catch { case e: HiveException => throw new SemanticException(e) case e: MetaException => throw new SemanticException(e) @@ -299,9 +328,13 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with new MoveWork(null, null, null, lfd, false), conf).asInstanceOf[MoveTask]) } - // The move task depends on all root tasks. In the case of multi outputs, + // The move task depends on all root tasks. In the case of multiple outputs, // the moves are only started once all outputs are executed. - val hiveFileSinkOp = terminalOps.head.hiveOp + // Note: For a CTAS for a memory-only cached table, a MoveTask is still added as a child of + // the main SparkTask. However, there no effects from its execution, since the SELECT query + // output is piped to Shark's in-memory columnar storage builder, instead of a Hive tmp + // directory. + // TODO(harvey): Don't create a MoveTask in this case. 
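+      // For disk-backed MEMORY tables, on the other hand, the chain SparkTask -> MoveTask(s) ->
+      // SparkLoadTask (added below) ensures the cache is refreshed only after the updated files
+      // land in the table's data directory.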
mvTasks.foreach { moveTask => rootTasks.foreach { rootTask => rootTask.addDependentTask(moveTask) @@ -321,6 +354,44 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with } */ } + + if (qb.cacheMode == CacheType.MEMORY) { + // Create a SparkLoadTask used to scan and load disk contents into the cache. + val sparkLoadWork = if (qb.isCTAS) { + // For cached tables, Shark-specific table properties should be set in + // analyzeCreateTable(). + val tblProps = qb.createTableDesc.getTblProps + + // No need to create a filter, since the entire table data directory should be loaded, nor + // pass partition specifications, since partitioned tables can't be created from CTAS. + val sparkLoadWork = new SparkLoadWork( + qb.createTableDesc.getDatabaseName, + qb.createTableDesc.getTableName, + SparkLoadWork.CommandTypes.NEW_ENTRY, + qb.cacheMode) + sparkLoadWork + } else { + // Split from 'databaseName.tableName' + val tableNameSplit = qb.targetTableDesc.getTableName.split('.') + val databaseName = tableNameSplit(0) + val cachedTableName = tableNameSplit(1) + val hiveTable = db.getTable(databaseName, cachedTableName) + // None if the table isn't partitioned, or if the partition specified doesn't exist. + val partSpecOpt = Option(qb.getMetaData.getDestPartitionForAlias( + qb.getParseInfo.getClauseNamesForDest.head)).map(_.getSpec) + SparkLoadWork( + db, + conf, + hiveTable, + partSpecOpt, + isOverwrite = !qb.getParseInfo.isInsertIntoTable(cachedTableName)) + } + // Add a SparkLoadTask as a dependent of all MoveTasks, so that when executed, the table's + // (or table partition's) data directory will already contain updates that should be + // loaded into memory. + val sparkLoadTask = TaskFactory.get(sparkLoadWork, conf) + mvTasks.foreach(_.addDependentTask(sparkLoadTask)) + } } // For CTAS, generate a DDL task to create the table. This task should be a @@ -344,11 +415,108 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with rootTasks.head.addDependentTask(crtTblTask) } } + + def analyzeCreateTable(rootAST: ASTNode, queryBlock: QueryBlock): Option[ASTNode] = { + // If we detect that the CREATE TABLE is part of a CTAS, then this is set to the root node of + // the query command (i.e., the root node of the SELECT statement). + var queryStmtASTNode: Option[ASTNode] = None + + // TODO(harvey): We might be able to reuse the QB passed into this method, as long as it was + // created after the super.analyzeInternal() call. That QB and the createTableDesc + // should have everything (e.g. isCTAS(), partCols). Note that the QB might not be + // accessible from getParseContext(), since the SemanticAnalyzer#analyzeInternal() + // doesn't set (this.qb = qb) for a non-CTAS. + // True if the command is a CREATE TABLE, but not a CTAS. + var isRegularCreateTable = true + var isHivePartitioned = false + + for (ch <- rootAST.getChildren) { + ch.asInstanceOf[ASTNode].getToken.getType match { + case HiveParser.TOK_QUERY => { + isRegularCreateTable = false + queryStmtASTNode = Some(ch.asInstanceOf[ASTNode]) + } + case _ => Unit + } + } + + var ddlTasks: Seq[DDLTask] = Nil + val createTableDesc = if (isRegularCreateTable) { + // Unfortunately, we have to comb the root tasks because for CREATE TABLE, + // SemanticAnalyzer#analyzeCreateTable() does't set the CreateTableDesc in its QB. 
+ ddlTasks = rootTasks.filter(_.isInstanceOf[DDLTask]).asInstanceOf[Seq[DDLTask]] + if (ddlTasks.isEmpty) null else ddlTasks.head.getWork.getCreateTblDesc + } else { + getParseContext.getQB.getTableDesc + } + + // Update the QueryBlock passed into this method. + // TODO(harvey): Remove once the TODO above is fixed. + queryBlock.setTableDesc(createTableDesc) + + // 'createTableDesc' is NULL if there is an IF NOT EXISTS condition and the target table + // already exists. + if (createTableDesc != null) { + val tableName = createTableDesc.getTableName + val checkTableName = SharkConfVars.getBoolVar(conf, SharkConfVars.CHECK_TABLENAME_FLAG) + // Note that the CreateTableDesc's table properties are Java Maps, but the TableDesc's table + // properties, which are used during execution, are Java Properties. + val createTableProperties: JavaMap[String, String] = createTableDesc.getTblProps() + + // There are two cases that will enable caching: + // 1) Table name includes "_cached" or "_tachyon". + // 2) The "shark.cache" table property is "true", or the string representation of a supported + // cache mode (memory, memory-only, Tachyon). + var cacheMode = CacheType.fromString( + createTableProperties.get(SharkTblProperties.CACHE_FLAG.varname)) + if (checkTableName) { + if (tableName.endsWith("_cached")) { + cacheMode = CacheType.MEMORY + } else if (tableName.endsWith("_tachyon")) { + cacheMode = CacheType.TACHYON + } + } + + // Continue planning based on the 'cacheMode' read. + val shouldCache = CacheType.shouldCache(cacheMode) + if (shouldCache) { + if (cacheMode == CacheType.MEMORY_ONLY || cacheMode == CacheType.TACHYON) { + val serDeName = createTableDesc.getSerName + if (serDeName == null || serDeName == classOf[LazySimpleSerDe].getName) { + // Hive's SemanticAnalyzer optimizes based on checks for LazySimpleSerDe, which causes + // casting exceptions for cached table scans during runtime. Use a simple SerDe wrapper + // to guard against these optimizations. + createTableDesc.setSerName(classOf[LazySimpleSerDeWrapper].getName) + } + } + createTableProperties.put(SharkTblProperties.CACHE_FLAG.varname, cacheMode.toString) + } + + // For CTAS ('isRegularCreateTable' is false), the MemoryStoreSinkOperator creates a new + // table metadata entry in the MemoryMetadataManager. The SparkTask that encloses the + // MemoryStoreSinkOperator will have a child Hive DDLTask, which creates a new table metadata + // entry in the Hive metastore. See genMapRedTasks() for SparkTask creation. + if (isRegularCreateTable && shouldCache) { + // In Hive, a CREATE TABLE command is handled by a DDLTask, created by + // SemanticAnalyzer#analyzeCreateTable(), in 'rootTasks'. The DDL tasks' execution succeeds + // only if the CREATE TABLE is valid. So, hook a SharkDDLTask as a child of the Hive DDLTask + // so that Shark metadata is updated only if the Hive task execution is successful. + val hiveDDLTask = ddlTasks.head; + val sharkDDLWork = new SharkDDLWork(createTableDesc) + sharkDDLWork.cacheMode = cacheMode + hiveDDLTask.addDependentTask(TaskFactory.get(sharkDDLWork, conf)) + } + + queryBlock.cacheMode = cacheMode + queryBlock.setTableDesc(createTableDesc) + } + queryStmtASTNode + } + } object SharkSemanticAnalyzer extends LogHelper { - /** * The reflection object used to invoke convertRowSchemaToViewSchema. 
*/ @@ -363,13 +531,22 @@ object SharkSemanticAnalyzer extends LogHelper { private val viewsExpandedField = classOf[SemanticAnalyzer].getDeclaredField("viewsExpanded") viewsExpandedField.setAccessible(true) + private def getHivePartitionKey(qb: QB): String = { + val selectClauseKey = qb.getParseInfo.getClauseNamesForDest.head + val destPartition = qb.getMetaData.getDestPartitionForAlias(selectClauseKey) + val partitionColumns = destPartition.getTable.getPartCols.map(_.getName) + val partitionColumnToValue = destPartition.getSpec + MemoryMetadataManager.makeHivePartitionKeyStr(partitionColumns, partitionColumnToValue) + } + /** * Given a Hive top operator (e.g. TableScanOperator), find all the file sink * operators (aka file output operator). */ - private def findAllHiveFileSinkOperators(op: HiveOperator): Seq[HiveOperator] = { + private def findAllHiveFileSinkOperators(op: HiveOperator[_<: HiveDesc]) + : Seq[HiveOperator[_<: HiveDesc]] = { if (op.getChildOperators() == null || op.getChildOperators().size() == 0) { - Seq[HiveOperator](op) + Seq[HiveOperator[_<: HiveDesc]](op) } else { op.getChildOperators().flatMap(findAllHiveFileSinkOperators(_)).distinct } @@ -384,7 +561,7 @@ object SharkSemanticAnalyzer extends LogHelper { */ private def breakHivePlanByStages(terminalOps: Seq[TerminalOperator]) = { val reduceSinks = new scala.collection.mutable.HashSet[ReduceSinkOperator] - val queue = new scala.collection.mutable.Queue[Operator[_]] + val queue = new scala.collection.mutable.Queue[Operator[_ <: HiveDesc]] queue ++= terminalOps while (!queue.isEmpty) { @@ -399,15 +576,5 @@ object SharkSemanticAnalyzer extends LogHelper { } logDebug("Found %d ReduceSinkOperator's.".format(reduceSinks.size)) - - reduceSinks.foreach { op => - val hiveOp = op.asInstanceOf[Operator[HiveOperator]].hiveOp - if (hiveOp.getChildOperators() != null) { - hiveOp.getChildOperators().foreach { child => - logDebug("Removing child %s from %s".format(child, hiveOp)) - hiveOp.removeChild(child) - } - } - } } } diff --git a/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala b/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala index 91215988..721ce115 100755 --- a/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala +++ b/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala @@ -19,7 +19,7 @@ package shark.parse import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.parse.{ASTNode, BaseSemanticAnalyzer, DDLSemanticAnalyzer, - SemanticAnalyzerFactory, ExplainSemanticAnalyzer, SemanticAnalyzer} + ExplainSemanticAnalyzer, LoadSemanticAnalyzer, SemanticAnalyzerFactory, SemanticAnalyzer} import shark.SharkConfVars @@ -30,18 +30,19 @@ object SharkSemanticAnalyzerFactory { * Return a semantic analyzer for the given ASTNode. 
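+   * Shark substitutes its own analyzers for queries, for EXPLAIN (when SharkConfVars.EXPLAIN_MODE
+   * is set to "shark"), and for DDL and LOAD statements; any other statement falls through to the
+   * analyzer Hive itself would pick.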
*/ def get(conf: HiveConf, tree:ASTNode): BaseSemanticAnalyzer = { - val baseSem = SemanticAnalyzerFactory.get(conf, tree) - - if (baseSem.isInstanceOf[SemanticAnalyzer]) { - new SharkSemanticAnalyzer(conf) - } else if (baseSem.isInstanceOf[ExplainSemanticAnalyzer] && - SharkConfVars.getVar(conf, SharkConfVars.EXPLAIN_MODE) == "shark") { - new SharkExplainSemanticAnalyzer(conf) - } else if (baseSem.isInstanceOf[DDLSemanticAnalyzer]) { - new SharkDDLSemanticAnalyzer(conf) - } else { - baseSem + val explainMode = SharkConfVars.getVar(conf, SharkConfVars.EXPLAIN_MODE) == "shark" + + SemanticAnalyzerFactory.get(conf, tree) match { + case _: SemanticAnalyzer => + new SharkSemanticAnalyzer(conf) + case _: ExplainSemanticAnalyzer if explainMode => + new SharkExplainSemanticAnalyzer(conf) + case _: DDLSemanticAnalyzer => + new SharkDDLSemanticAnalyzer(conf) + case _: LoadSemanticAnalyzer => + new SharkLoadSemanticAnalyzer(conf) + case sem: BaseSemanticAnalyzer => + sem } } } - diff --git a/src/main/scala/shark/repl/Main.scala b/src/main/scala/shark/repl/Main.scala index 1fa22da5..890a74ef 100755 --- a/src/main/scala/shark/repl/Main.scala +++ b/src/main/scala/shark/repl/Main.scala @@ -17,11 +17,21 @@ package shark.repl +import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException + + /** * Shark's REPL entry point. */ object Main { + try { + LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => // Ignore the error. + } + private var _interp: SharkILoop = null def interp = _interp diff --git a/src/main/scala/shark/tachyon/TachyonUtil.scala b/src/main/scala/shark/tachyon/TachyonUtil.scala index 3a50eead..25207d91 100644 --- a/src/main/scala/shark/tachyon/TachyonUtil.scala +++ b/src/main/scala/shark/tachyon/TachyonUtil.scala @@ -22,8 +22,7 @@ import java.util.BitSet import org.apache.spark.rdd.RDD -import shark.memstore2.TablePartition - +import shark.memstore2.{TablePartition, TablePartitionStats} /** @@ -32,17 +31,27 @@ import shark.memstore2.TablePartition * even without Tachyon jars. */ abstract class TachyonUtil { + def pushDownColumnPruning(rdd: RDD[_], columnUsed: BitSet): Boolean def tachyonEnabled(): Boolean - def tableExists(tableName: String): Boolean + def tableExists(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean + + def dropTable(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean - def dropTable(tableName: String): Boolean + def createDirectory(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean - def getTableMetadata(tableName: String): ByteBuffer + def renameDirectory(oldName: String, newName: String): Boolean - def createRDD(tableName: String): RDD[TablePartition] + def createRDD( + tableKey: String, + hivePartitionKeyOpt: Option[String] + ): Seq[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] - def createTableWriter(tableName: String, numColumns: Int): TachyonTableWriter + def createTableWriter( + tableKey: String, + hivePartitionKey: Option[String], + numColumns: Int + ): TachyonTableWriter } diff --git a/src/main/scala/shark/tgf/TGF.scala b/src/main/scala/shark/tgf/TGF.scala new file mode 100644 index 00000000..b57d4053 --- /dev/null +++ b/src/main/scala/shark/tgf/TGF.scala @@ -0,0 +1,303 @@ +/* + * Copyright (C) 2013 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.tgf + +import java.sql.Timestamp +import java.util.Date + +import scala.language.implicitConversions +import scala.reflect.{classTag, ClassTag} +import scala.util.parsing.combinator._ + +import org.apache.spark.rdd.RDD + +import shark.api._ +import shark.SharkContext +import java.lang.reflect.Method + +/** + * This object is responsible for handling TGF (Table Generating Function) commands. + * + * {{{ + * -- TGF Commands -- + * GENERATE tgfname(param1, param2, ... , param_n) + * GENERATE tgfname(param1, param2, ... , param_n) AS tablename + * }}} + * + * Parameters can either be of primitive types, eg int, or of type RDD[Product]. TGF.execute() + * will use reflection looking for an object of name "tgfname", invoking apply() with the + * primitive values. If the type of a parameter to apply() is RDD[Product], it will assume the + * parameter is the name of a table, which it will turn into an RDD before invoking apply(). + * + * For example, "GENERATE MyObj(25, emp)" will invoke + * MyObj.apply(25, sc.sql2rdd("select * from emp")) + * , assuming the TGF object (MyObj) has an apply function that takes an int and an RDD[Product]. + * + * The "as" version of the command saves the output in a new table named "tablename", + * whereas the other version returns a ResultSet. + * + * -- Defining TGF objects -- + * TGF objects need to have an apply() function and take an arbitrary number of either primitive + * or RDD[Product] typed parameters. The apply() function should either return an RDD[Product] + * or RDDSchema. When the former case is used, the returned table's schema and column names need + * to be defined through a Java annotation called @Schema. Here is a short example: + * {{{ + * object MyTGF1 { + * \@Schema(spec = "name string, age int") + * def apply(table1: RDD[(String, String, Int)]): RDD[Product] = { + * // code that manipulates table1 and returns a new RDD of tuples + * } + * } + * }}} + * + * Sometimes, the TGF dynamically determines the number or types of columns returned. In this case, + * the TGF can use the RDDSchema return type instead of Java annotations. RDDSchema simply contains + * a schema string and an RDD of results. For example: + * {{{ + * object MyTGF2 { + * \@Schema(spec = "name string, age int") + * def apply(table1: RDD[(String, String, Int)]): RDD[Product] = { + * // code that manipulates table1 and creates a result rdd + * return RDDSchema(rdd.asInstanceOf[RDD[Seq[_]]], "name string, age int") + * } + * } + * }}} + * + * Sometimes the TGF needs to internally make SQL calls. For that, it needs access to a + * SharkContext object. Therefore, + * {{{ + * def apply(sc: SharkContext, table1: RDD[(String, String, Int)]): RDD[Product] = { + * // code that can use sc, for example by calling sc.sql2rdd() + * // code that manipulates table1 and returns a new RDD of tuples + * } + * }}} + */ + +object TGF { + private val parser = new TGFParser + + /** + * Executes a TGF command and gives back the ResultSet. + * Mainly to be used from SharkContext (e.g. runSql()) + * + * @param sql TGF command, e.g. 
"GENERATE name(params) AS tablename" + * @param sc SharkContext + * @return ResultSet containing the results of the command + */ + def execute(sql: String, sc: SharkContext): ResultSet = { + val ast = parser.parseAll(parser.tgf, sql).getOrElse( + throw new QueryExecutionException("TGF parse error: "+ sql)) + + val (tableNameOpt, tgfName, params) = ast match { + case (tgfName, params) => + (None, tgfName.asInstanceOf[String], params.asInstanceOf[List[String]]) + case (tableName, tgfName, params) => + (Some(tableName.asInstanceOf[String]), tgfName.asInstanceOf[String], + params.asInstanceOf[List[String]]) + } + + val obj = reflectInvoke(tgfName, params, sc) + val (rdd, schema) = getSchema(obj, tgfName) + + val (sharkSchema, resultArr) = tableNameOpt match { + case Some(tableName) => // materialize results + val helper = new RDDTableFunctions(rdd, schema.map { case (_, tpe) => toClassTag(tpe) }) + helper.saveAsTable(tableName, schema.map{ case (name, _) => name}) + (Array[ColumnDesc](), Array[Array[Object]]()) + case None => // return results + val newSchema = schema.map { case (name, tpe) => + new ColumnDesc(name, DataTypes.fromClassTag(toClassTag(tpe))) + } + val res = rdd.collect().map{p => p.map( _.asInstanceOf[Object] ).toArray} + (newSchema.toArray, res) + } + new ResultSet(sharkSchema, resultArr) + } + + private def getMethod(tgfName: String, methodName: String): Option[Method] = { + val tgfClazz = try { + Thread.currentThread().getContextClassLoader.loadClass(tgfName) + } catch { + case ex: ClassNotFoundException => + throw new QueryExecutionException("Couldn't find TGF class: " + tgfName) + } + + val methods = tgfClazz.getDeclaredMethods.filter(_.getName == methodName) + if (methods.isEmpty) None else Some(methods(0)) + } + + private def getSchema(tgfOutput: Object, tgfName: String): (RDD[Seq[_]], Seq[(String,String)]) = { + tgfOutput match { + case rddSchema: RDDSchema => + val schema = parser.parseAll(parser.schema, rddSchema.schema) + + (rddSchema.rdd, schema.get) + case rdd: RDD[Product] => + val applyMethod = getMethod(tgfName, "apply") + if (applyMethod == None) { + throw new QueryExecutionException("TGF lacking apply() method") + } + + val annotations = applyMethod.get.getAnnotation(classOf[Schema]) + if (annotations == null || annotations.spec() == null) { + throw new QueryExecutionException("No schema annotation found for TGF") + } + + // TODO: How can we compare schema with None? 
+ val schema = parser.parseAll(parser.schema, annotations.spec()) + if (schema.isEmpty) { + throw new QueryExecutionException( + "Error parsing TGF schema annotation (@Schema(spec=...)") + } + + (rdd.map(_.productIterator.toList), schema.get) + case _ => + throw new QueryExecutionException("TGF output needs to be of type RDD or RDDSchema") + } + } + + private def reflectInvoke(tgfName: String, paramStrs: Seq[String], sc: SharkContext) = { + + val applyMethodOpt = getMethod(tgfName, "apply") + if (applyMethodOpt.isEmpty) { + throw new QueryExecutionException("TGF " + tgfName + " needs to implement apply()") + } + + val applyMethod = applyMethodOpt.get + + val typeNames: Seq[String] = applyMethod.getParameterTypes.toList.map(_.toString) + + val augParams = + if (!typeNames.isEmpty && typeNames.head.startsWith("class shark.SharkContext")) { + Seq("sc") ++ paramStrs + } else { + paramStrs + } + + if (augParams.length != typeNames.length) { + throw new QueryExecutionException("Expecting " + typeNames.length + + " parameters to " + tgfName + ", got " + augParams.length) + } + + val params = (augParams.toList zip typeNames.toList).map { + case (param: String, tpe: String) if tpe.startsWith("class shark.SharkContext") => + sc + case (param: String, tpe: String) if tpe.startsWith("class org.apache.spark.rdd.RDD") => + tableRdd(sc, param) + case (param: String, tpe: String) if tpe.startsWith("long") => + param.toLong + case (param: String, tpe: String) if tpe.startsWith("int") => + param.toInt + case (param: String, tpe: String) if tpe.startsWith("double") => + param.toDouble + case (param: String, tpe: String) if tpe.startsWith("float") => + param.toFloat + case (param: String, tpe: String) if tpe.startsWith("class java.lang.String") || + tpe.startsWith("class String") => + param.stripPrefix("\"").stripSuffix("\"") + case (param: String, tpe: String) => + throw new QueryExecutionException(s"Expected TGF parameter type: $tpe ($param)") + } + + applyMethod.invoke(null, params.asInstanceOf[List[Object]] : _*) + } + + private def toClassTag(tpe: String): ClassTag[_] = { + if (tpe == "boolean") classTag[Boolean] + else if (tpe == "tinyint") classTag[Byte] + else if (tpe == "smallint") classTag[Short] + else if (tpe == "int") classTag[Integer] + else if (tpe == "bigint") classTag[Long] + else if (tpe == "float") classTag[Float] + else if (tpe == "double") classTag[Double] + else if (tpe == "string") classTag[String] + else if (tpe == "timestamp") classTag[Timestamp] + else if (tpe == "date") classTag[Date] + else { + throw new QueryExecutionException("Unknown column type specified in schema (" + tpe + ")") + } + } + + def tableRdd(sc: SharkContext, tableName: String): RDD[_] = { + val rdd = sc.sql2rdd("SELECT * FROM " + tableName) + rdd.schema.size match { + case 2 => new TableRDD2(rdd, Seq()) + case 3 => new TableRDD3(rdd, Seq()) + case 4 => new TableRDD4(rdd, Seq()) + case 5 => new TableRDD5(rdd, Seq()) + case 6 => new TableRDD6(rdd, Seq()) + case 7 => new TableRDD7(rdd, Seq()) + case 8 => new TableRDD8(rdd, Seq()) + case 9 => new TableRDD9(rdd, Seq()) + case 10 => new TableRDD10(rdd, Seq()) + case 11 => new TableRDD11(rdd, Seq()) + case 12 => new TableRDD12(rdd, Seq()) + case 13 => new TableRDD13(rdd, Seq()) + case 14 => new TableRDD14(rdd, Seq()) + case 15 => new TableRDD15(rdd, Seq()) + case 16 => new TableRDD16(rdd, Seq()) + case 17 => new TableRDD17(rdd, Seq()) + case 18 => new TableRDD18(rdd, Seq()) + case 19 => new TableRDD19(rdd, Seq()) + case 20 => new TableRDD20(rdd, Seq()) + case 21 => new 
TableRDD21(rdd, Seq()) + case 22 => new TableRDD22(rdd, Seq()) + case _ => new TableSeqRDD(rdd) + } + } +} + +case class RDDSchema(rdd: RDD[Seq[_]], schema: String) + +private class TGFParser extends JavaTokenParsers { + + // Code to enable case-insensitive modifiers to strings, e.g. + // "Berkeley".ci will match "berkeley" + class MyString(str: String) { + def ci: Parser[String] = ("(?i)" + str).r + } + + implicit def stringToRichString(str: String): MyString = new MyString(str) + + def tgf: Parser[Any] = saveTgf | basicTgf + + /** + * @return Tuple2 containing a TGF method name and a List of parameters as strings + */ + def basicTgf: Parser[(String, List[String])] = { + ("GENERATE".ci ~> methodName) ~ (("(" ~> repsep(param, ",")) <~ ")") ^^ + { case id1 ~ x => (id1, x.asInstanceOf[List[String]]) } + } + + /** + * @return Tuple3 containing a table name, TGF method name and a List of parameters as strings + */ + def saveTgf: Parser[(String, String, List[String])] = { + (("GENERATE".ci ~> methodName) ~ (("(" ~> repsep(param, ",")) <~ ")")) ~ (("AS".ci) ~> + ident) ^^ { case id1 ~ x ~ id2 => (id2, id1, x.asInstanceOf[List[String]]) } + } + + def schema: Parser[Seq[(String,String)]] = repsep(nameType, ",") + + def nameType: Parser[(String,String)] = ident ~ ident ^^ { case name~tpe => Tuple2(name, tpe) } + + def param: Parser[Any] = stringLiteral | floatingPointNumber | decimalNumber | ident | + failure("Expected a string, number, or identifier as parameters in TGF") + + def methodName: Parser[String] = """[a-zA-Z_][\w\.]*""".r +} diff --git a/src/main/scala/shark/util/BloomFilter.scala b/src/main/scala/shark/util/BloomFilter.scala index 6a26b9e5..3a798d28 100644 --- a/src/main/scala/shark/util/BloomFilter.scala +++ b/src/main/scala/shark/util/BloomFilter.scala @@ -1,9 +1,27 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.util import java.util.BitSet import java.nio.charset.Charset -import scala.math._ -import com.google.common.primitives.Bytes + +import scala.math.{ceil, log} + import com.google.common.primitives.Ints import com.google.common.primitives.Longs @@ -16,13 +34,12 @@ import com.google.common.primitives.Longs * @param expectedSize is the number of elements to be contained in the filter. * @param numHashes is the number of hash functions. 
* @author Ram Sriharsha (harshars at yahoo-inc dot com) - * @date 07/07/2013 */ class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int) - extends AnyRef with Serializable{ + extends AnyRef with Serializable { val SEED = System.getProperty("shark.bloomfilter.seed","1234567890").toInt - val bitSetSize = ceil(numBitsPerElement * expectedSize).toInt + val bitSetSize = math.ceil(numBitsPerElement * expectedSize).toInt val bitSet = new BitSet(bitSetSize) /** @@ -51,7 +68,7 @@ class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int) * Optimization to allow reusing the same input buffer by specifying * the length of the buffer that contains the bytes to be hashed. * @param data is the bytes to be hashed. - * @param length is the length of the buffer to examine. + * @param len is the length of the buffer to examine. */ def add(data: Array[Byte], len: Int) { val hashes = hash(data, numHashes, len) @@ -96,9 +113,9 @@ class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int) * Optimization to allow reusing the same input buffer by specifying * the length of the buffer that contains the bytes to be hashed. * @param data is the bytes to be hashed. - * @param length is the length of the buffer to examine. + * @param len is the length of the buffer to examine. * @return true with some false positive probability and false if the - * bytes is not contained in the bloom filter. + * bytes is not contained in the bloom filter. */ def contains(data: Array[Byte], len: Int): Boolean = { !hash(data,numHashes, len).exists { @@ -119,14 +136,17 @@ class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int) MurmurHash3_x86_128.hash(data, SEED + i, len, results) a(i) = results(0).abs var j = i + 1 - if (j < n) + if (j < n) { a(j) = results(1).abs + } j += 1 - if (j < n) + if (j < n) { a(j) = results(2).abs + } j += 1 - if (j < n) + if (j < n) { a(j) = results(3).abs + } i += 1 } a @@ -139,4 +159,4 @@ object BloomFilter { def numHashes(fpp: Double, expectedSize: Int) = ceil(-(log(fpp) / log(2))).toInt -} \ No newline at end of file +} diff --git a/src/main/scala/shark/util/HiveUtils.scala b/src/main/scala/shark/util/HiveUtils.scala new file mode 100644 index 00000000..46465993 --- /dev/null +++ b/src/main/scala/shark/util/HiveUtils.scala @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
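A rough usage sketch for the Bloom filter above, with arbitrarily chosen sizing values and using only the byte-array overloads and the numHashes helper defined in this file:
{{{
import shark.util.BloomFilter

// About 10 bits per element for an expected 1000 entries; numHashes derives the
// hash count from a target false-positive probability.
val filter = new BloomFilter(10.0, 1000, BloomFilter.numHashes(0.01, 1000))

val present = "val_407".getBytes("UTF-8")
filter.add(present, present.length)
assert(filter.contains(present, present.length))    // always true once added

val absent = "val_408".getBytes("UTF-8")
val maybe = filter.contains(absent, absent.length)  // usually false; may be a false positive
}}}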
+ */ + +package shark.util + +import java.util.{Arrays => JArrays, ArrayList => JArrayList} +import java.util.{HashMap => JHashMap, HashSet => JHashSet} +import java.util.Properties + +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.Constants.META_TABLE_PARTITION_COLUMNS +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.UnionStructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.ql.exec.DDLTask +import org.apache.hadoop.hive.ql.hooks.{ReadEntity, WriteEntity} +import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, DDLWork, DropTableDesc} + +import shark.api.{DataType, DataTypes} +import shark.memstore2.SharkTblProperties + + +private[shark] object HiveUtils { + + def getJavaPrimitiveObjectInspector(c: ClassTag[_]): PrimitiveObjectInspector = { + getJavaPrimitiveObjectInspector(DataTypes.fromClassTag(c)) + } + + def getJavaPrimitiveObjectInspector(t: DataType): PrimitiveObjectInspector = t match { + case DataTypes.BOOLEAN => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case DataTypes.TINYINT => PrimitiveObjectInspectorFactory.javaByteObjectInspector + case DataTypes.SMALLINT => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case DataTypes.INT => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DataTypes.BIGINT => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case DataTypes.FLOAT => PrimitiveObjectInspectorFactory.javaFloatObjectInspector + case DataTypes.DOUBLE => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case DataTypes.TIMESTAMP => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector + case DataTypes.STRING => PrimitiveObjectInspectorFactory.javaStringObjectInspector + } + + /** + * Return a UnionStructObjectInspector that combines the StructObjectInspectors for the table + * schema and the partition columns, which are virtual in Hive. + */ + def makeUnionOIForPartitionedTable( + partProps: Properties, + partSerDe: Deserializer): UnionStructObjectInspector = { + val partCols = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) + val partColNames = new JArrayList[String] + val partColObjectInspectors = new JArrayList[ObjectInspector] + partCols.trim().split("/").foreach { colName => + partColNames.add(colName) + partColObjectInspectors.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector) + } + + val partColObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( + partColNames, partColObjectInspectors) + val oiList = JArrays.asList( + partSerDe.getObjectInspector.asInstanceOf[StructObjectInspector], + partColObjectInspector.asInstanceOf[StructObjectInspector]) + // New oi is union of table + partition object inspectors + ObjectInspectorFactory.getUnionStructObjectInspector(oiList) + } + + /** + * Execute the create table DDL operation against Hive's metastore. 
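The helper documented here, defined next along with dropTableInHive, drives the metastore through a DDLTask. A minimal usage sketch with illustrative table and column names, reusing the same ClassTag-to-DataType mapping seen elsewhere in this patch:
{{{
import scala.reflect.classTag
import shark.util.HiveUtils

// Register a two-column table in the metastore; both helpers return true on success.
val created = HiveUtils.createTableInHive(
  "tgf_output_example",
  Seq("name", "age"),
  Seq(classTag[String], classTag[Integer]))
assert(created)

// Tear the table down again.
HiveUtils.dropTableInHive("tgf_output_example")
}}}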
+ */ + def createTableInHive( + tableName: String, + columnNames: Seq[String], + columnTypes: Seq[ClassTag[_]], + hiveConf: HiveConf = new HiveConf): Boolean = { + val schema = columnNames.zip(columnTypes).map { case (colName, classTag) => + new FieldSchema(colName, DataTypes.fromClassTag(classTag).hiveName, "") + } + + // Setup the create table descriptor with necessary information. + val createTableDesc = new CreateTableDesc() + createTableDesc.setTableName(tableName) + createTableDesc.setCols(new JArrayList[FieldSchema](schema)) + createTableDesc.setTblProps( + SharkTblProperties.initializeWithDefaults(new JHashMap[String, String]())) + createTableDesc.setInputFormat("org.apache.hadoop.mapred.TextInputFormat") + createTableDesc.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat") + createTableDesc.setSerName(classOf[shark.memstore2.ColumnarSerDe].getName) + createTableDesc.setNumBuckets(-1) + + // Execute the create table against the Hive metastore. + val work = new DDLWork(new JHashSet[ReadEntity], new JHashSet[WriteEntity], createTableDesc) + val taskExecutionStatus = executeDDLTaskDirectly(work, hiveConf) + taskExecutionStatus == 0 + } + + def dropTableInHive(tableName: String, hiveConf: HiveConf = new HiveConf): Boolean = { + // Setup the drop table descriptor with necessary information. + val dropTblDesc = new DropTableDesc( + tableName, + false /* expectView */, + false /* ifExists */, + false /* stringPartitionColumns */) + + // Execute the drop table against the metastore. + val work = new DDLWork(new JHashSet[ReadEntity], new JHashSet[WriteEntity], dropTblDesc) + val taskExecutionStatus = executeDDLTaskDirectly(work, hiveConf) + taskExecutionStatus == 0 + } + + /** + * Creates a DDLTask from the DDLWork given, and directly calls DDLTask#execute(). Returns 0 if + * the create table command is executed successfully. + * This is safe to use for all DDL commands except for AlterTableTypes.ARCHIVE, which actually + * requires the DriverContext created in Hive Driver#execute(). + */ + def executeDDLTaskDirectly(ddlWork: DDLWork, hiveConf: HiveConf): Int = { + val task = new DDLTask() + task.initialize(hiveConf, null /* queryPlan */, null /* ctx: DriverContext */) + task.setWork(ddlWork) + task.execute(null /* driverContext */) + } +} diff --git a/src/main/scala/shark/util/MurmurHash3_x86_128.scala b/src/main/scala/shark/util/MurmurHash3_x86_128.scala index 5dcc6068..ff230ee5 100644 --- a/src/main/scala/shark/util/MurmurHash3_x86_128.scala +++ b/src/main/scala/shark/util/MurmurHash3_x86_128.scala @@ -1,7 +1,23 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.util import java.lang.Integer.{ rotateLeft => rotl } -import scala.math._ /** *

The MurmurHash3_x86_128(...) is a fast, non-cryptographic, 128-bit hash @@ -109,7 +125,7 @@ object MurmurHash3_x86_128 { * @param seed is the seed for the murmurhash algorithm. * @param length is the length of the buffer to use for hashing. * @param results is the output buffer to store the four ints that are returned, - * should have size atleast 4. + * should have size at least 4. */ @inline final def hash(data: Array[Byte], seed: Int, length: Int, results: Array[Int]): Unit = { @@ -177,18 +193,18 @@ object MurmurHash3_x86_128 { * @param rem is the remainder of the byte array to examine. */ @inline final def getInt(data: Array[Byte], index: Int, rem: Int): Int = { - rem match { + rem match { case 3 => data(index) << 24 | - (data(index + 1) & 0xFF) << 16 | - (data(index + 2) & 0xFF) << 8 + (data(index + 1) & 0xFF) << 16 | + (data(index + 2) & 0xFF) << 8 case 2 => data(index) << 24 | - (data(index + 1) & 0xFF) << 16 + (data(index + 1) & 0xFF) << 16 case 1 => data(index) << 24 case 0 => 0 case _ => data(index) << 24 | - (data(index + 1) & 0xFF) << 16 | - (data(index + 2) & 0xFF) << 8 | - (data(index + 3) & 0xFF) + (data(index + 1) & 0xFF) << 16 | + (data(index + 2) & 0xFF) << 8 | + (data(index + 3) & 0xFF) } } -} \ No newline at end of file +} diff --git a/src/main/scala/shark/util/QueryRewriteUtils.scala b/src/main/scala/shark/util/QueryRewriteUtils.scala new file mode 100644 index 00000000..8d44f8a8 --- /dev/null +++ b/src/main/scala/shark/util/QueryRewriteUtils.scala @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark.util + +import org.apache.hadoop.hive.ql.parse.SemanticException + +import shark.memstore2.SharkTblProperties + + +object QueryRewriteUtils { + + def cacheToAlterTable(cmd: String): String = { + val cmdSplit = cmd.split(' ') + if (cmdSplit.size == 2) { + val tableName = cmdSplit(1) + "ALTER TABLE %s SET TBLPROPERTIES ('shark.cache' = 'true')".format(tableName) + } else { + throw new SemanticException( + s"CACHE accepts a single table name: 'CACHE

<table name>' (received command: '$cmd')") + } + } + + def uncacheToAlterTable(cmd: String): String = { + val cmdSplit = cmd.split(' ') + if (cmdSplit.size == 2) { + val tableName = cmdSplit(1) + "ALTER TABLE %s SET TBLPROPERTIES ('shark.cache' = 'false')".format(tableName) + } else { + throw new SemanticException( + s"UNCACHE accepts a single table name: 'UNCACHE <table name>
' (received command: '$cmd')") + } + } +} diff --git a/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala b/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala index 3f1d2eba..dbdf1ff6 100644 --- a/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala +++ b/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala @@ -22,35 +22,51 @@ import java.util.BitSet import org.apache.spark.rdd.RDD -import shark.memstore2.TablePartition - +import shark.memstore2.{Table, TablePartition, TablePartitionStats} class TachyonUtilImpl(val master: String, val warehousePath: String) extends TachyonUtil { + override def pushDownColumnPruning(rdd: RDD[_], columnUsed: BitSet): Boolean = false override def tachyonEnabled(): Boolean = false - override def tableExists(tableName: String): Boolean = { + override def tableExists(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = { + throw new UnsupportedOperationException( + "This version of Shark is not compiled with Tachyon support.") + } + + override def dropTable(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = { throw new UnsupportedOperationException( "This version of Shark is not compiled with Tachyon support.") } - override def dropTable(tableName: String): Boolean = { + override def createDirectory( + tableKey: String, + hivePartitionKeyOpt: Option[String]): Boolean = { throw new UnsupportedOperationException( "This version of Shark is not compiled with Tachyon support.") } - override def getTableMetadata(tableName: String): ByteBuffer = { + override def renameDirectory( + oldName: String, + newName: String): Boolean = { throw new UnsupportedOperationException( "This version of Shark is not compiled with Tachyon support.") } - override def createRDD(tableName: String): RDD[TablePartition] = { + override def createRDD( + tableKey: String, + hivePartitionKeyOpt: Option[String] + ): Seq[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { throw new UnsupportedOperationException( "This version of Shark is not compiled with Tachyon support.") } - override def createTableWriter(tableName: String, numColumns: Int): TachyonTableWriter = { + override def createTableWriter( + tableKey: String, + hivePartitionKeyOpt: Option[String], + numColumns: Int + ): TachyonTableWriter = { throw new UnsupportedOperationException( "This version of Shark is not compiled with Tachyon support.") } diff --git a/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala b/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala index 32f27dee..8e4eab8d 100644 --- a/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala +++ b/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala @@ -19,67 +19,127 @@ package shark.tachyon import java.nio.ByteBuffer import java.util.BitSet +import java.util.concurrent.{ConcurrentHashMap => ConcurrentJavaHashMap} -import scala.collection.JavaConverters._ +import scala.collection.JavaConversions._ -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{EmptyRDD, RDD, UnionRDD} import tachyon.client.TachyonFS import tachyon.client.table.{RawTable, RawColumn} -import shark.SharkEnv -import shark.memstore2.TablePartition +import shark.{LogHelper, SharkEnv} +import shark.execution.serialization.JavaSerializer +import shark.memstore2.{MemoryMetadataManager, TablePartition, TablePartitionStats} /** * An abstraction for the Tachyon APIs. 
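Every (tableKey, hivePartitionKeyOpt) pair in the API above maps to a single warehouse directory. The sketch below mirrors the getPath helper defined just after this comment; the warehouse path and keys are illustrative.
{{{
// Standalone version of the layout rule used by TachyonUtilImpl.getPath.
def tachyonPath(warehousePath: String,
                tableKey: String,
                hivePartitionKeyOpt: Option[String]): String =
  warehousePath + "/" + tableKey + hivePartitionKeyOpt.map("/" + _).getOrElse("")

assert(tachyonPath("/tachyon_warehouse", "srcpart_cached", Some("keypart=1"))
  == "/tachyon_warehouse/srcpart_cached/keypart=1")
assert(tachyonPath("/tachyon_warehouse", "test_cached", None)
  == "/tachyon_warehouse/test_cached")
}}}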
*/ -class TachyonUtilImpl(val master: String, val warehousePath: String) extends TachyonUtil { +class TachyonUtilImpl( + val master: String, + val warehousePath: String) + extends TachyonUtil + with LogHelper { + + private val INSERT_FILE_PREFIX = "insert_" + + private val _fileNameMappings = new ConcurrentJavaHashMap[String, Int]() val client = if (master != null && master != "") TachyonFS.get(master) else null + private def getUniqueFilePath(parentDirectory: String): String = { + val parentDirectoryLower = parentDirectory.toLowerCase + val currentInsertNum = if (_fileNameMappings.containsKey(parentDirectoryLower)) { + _fileNameMappings.get(parentDirectoryLower) + } else { + 0 + } + var nextInsertNum = currentInsertNum + 1 + var filePath = parentDirectoryLower + "/" + INSERT_FILE_PREFIX + // Make sure there aren't file conflicts. This could occur if the directory was created in a + // previous Shark session. + while (client.exist(filePath + nextInsertNum)) { + nextInsertNum = nextInsertNum + 1 + } + _fileNameMappings.put(parentDirectoryLower, nextInsertNum) + filePath + nextInsertNum + } + if (master != null && warehousePath == null) { throw new TachyonException("TACHYON_MASTER is set. However, TACHYON_WAREHOUSE_PATH is not.") } - def getPath(tableName: String): String = warehousePath + "/" + tableName + private def getPath(tableKey: String, hivePartitionKeyOpt: Option[String]): String = { + val hivePartitionKey = if (hivePartitionKeyOpt.isDefined) { + "/" + hivePartitionKeyOpt.get + } else { + "" + } + warehousePath + "/" + tableKey + hivePartitionKey + } override def pushDownColumnPruning(rdd: RDD[_], columnUsed: BitSet): Boolean = { - if (rdd.isInstanceOf[TachyonTableRDD]) { + val isTachyonTableRdd = rdd.isInstanceOf[TachyonTableRDD] + if (isTachyonTableRdd) { rdd.asInstanceOf[TachyonTableRDD].setColumnUsed(columnUsed) - true - } else { - false } + isTachyonTableRdd } + override def tachyonEnabled(): Boolean = + (master != null && warehousePath != null && client.isConnected) - override def tachyonEnabled(): Boolean = (master != null && warehousePath != null) - - override def tableExists(tableName: String): Boolean = { - client.exist(getPath(tableName)) + override def tableExists(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = { + client.exist(getPath(tableKey, hivePartitionKeyOpt)) } - override def dropTable(tableName: String): Boolean = { + override def dropTable(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = { // The second parameter (true) means recursive deletion. 
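The insert-file naming in getUniqueFilePath above reduces to a per-directory counter that skips names already present in Tachyon, such as files left over from a previous Shark session. A toy model, with an in-memory set standing in for client.exist:
{{{
import scala.collection.mutable

// Files assumed to exist already in the table directory (e.g. from a prior session).
val existing = mutable.Set("/warehouse/t/insert_1")
val counters = mutable.Map.empty[String, Int]

def uniqueFilePath(dir: String): String = {
  var next = counters.getOrElse(dir, 0) + 1
  while (existing.contains(dir + "/insert_" + next)) next += 1
  counters(dir) = next
  dir + "/insert_" + next
}

assert(uniqueFilePath("/warehouse/t") == "/warehouse/t/insert_2")
assert(uniqueFilePath("/warehouse/t") == "/warehouse/t/insert_3")
}}}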
- client.delete(getPath(tableName), true) + client.delete(getPath(tableKey, hivePartitionKeyOpt), true) } - override def getTableMetadata(tableName: String): ByteBuffer = { - if (!tableExists(tableName)) { - throw new TachyonException("Table " + tableName + " does not exist in Tachyon") - } - client.getRawTable(getPath(tableName)).getMetadata() + override def createDirectory( + tableKey: String, + hivePartitionKeyOpt: Option[String]): Boolean = { + client.mkdir(getPath(tableKey, hivePartitionKeyOpt)) + } + + override def renameDirectory( + oldTableKey: String, + newTableKey: String): Boolean = { + val oldPath = getPath(oldTableKey, hivePartitionKeyOpt = None) + val newPath = getPath(newTableKey, hivePartitionKeyOpt = None) + client.rename(oldPath, newPath) } - override def createRDD(tableName: String): RDD[TablePartition] = { - new TachyonTableRDD(getPath(tableName), SharkEnv.sc) + override def createRDD( + tableKey: String, + hivePartitionKeyOpt: Option[String] + ): Seq[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = { + // Create a TachyonTableRDD for each raw table file in the directory. + val tableDirectory = getPath(tableKey, hivePartitionKeyOpt) + val files = client.ls(tableDirectory, false /* recursive */) + // The first path is just "{tableDirectory}/", so ignore it. + val rawTableFiles = files.subList(1, files.size) + val tableRDDsAndStats = rawTableFiles.map { filePath => + val serializedMetadata = client.getRawTable(client.getFileId(filePath)).getMetadata + val indexToStats = JavaSerializer.deserialize[collection.Map[Int, TablePartitionStats]]( + serializedMetadata.array()) + (new TachyonTableRDD(filePath, SharkEnv.sc), indexToStats) + } + tableRDDsAndStats } - override def createTableWriter(tableName: String, numColumns: Int): TachyonTableWriter = { + override def createTableWriter( + tableKey: String, + hivePartitionKeyOpt: Option[String], + numColumns: Int): TachyonTableWriter = { if (!client.exist(warehousePath)) { client.mkdir(warehousePath) } - new TachyonTableWriterImpl(getPath(tableName), numColumns) + val parentDirectory = getPath(tableKey, hivePartitionKeyOpt) + val filePath = getUniqueFilePath(parentDirectory) + new TachyonTableWriterImpl(filePath, numColumns) } } diff --git a/src/test/java/shark/JavaAPISuite.java b/src/test/java/shark/JavaAPISuite.java index 01f6fe58..49b0d2e8 100644 --- a/src/test/java/shark/JavaAPISuite.java +++ b/src/test/java/shark/JavaAPISuite.java @@ -48,13 +48,9 @@ public static void oneTimeSetUp() { // Intentionally leaving this here since SBT doesn't seem to display junit tests well ... System.out.println("running JavaAPISuite ================================================"); - sc = SharkEnv.initWithJavaSharkContext("JavaAPISuite", "local"); - - sc.sql("set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=" + - METASTORE_PATH + ";create=true"); - sc.sql("set hive.metastore.warehouse.dir=" + WAREHOUSE_PATH); - - sc.sql("set shark.test.data.path=" + TestUtils$.MODULE$.dataFilePath()); + // Check if the SharkEnv's SharkContext has already been initialized. If so, use that to + // instantiate a JavaSharkContext. + sc = SharkRunner.initWithJava(); // test sc.sql("drop table if exists test_java"); diff --git a/src/test/scala/shark/ColumnStatsSQLSuite.scala b/src/test/scala/shark/ColumnStatsSQLSuite.scala new file mode 100644 index 00000000..f0aa5931 --- /dev/null +++ b/src/test/scala/shark/ColumnStatsSQLSuite.scala @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2012 The Regents of The University California. 
+ * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark + +import org.apache.hadoop.io.BytesWritable + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME + +import org.apache.spark.rdd.RDD + +import shark.memstore2.MemoryMetadataManager + + +class ColumnStatsSQLSuite extends FunSuite with BeforeAndAfterAll { + + val sc: SharkContext = SharkRunner.init() + val sharkMetastore = SharkEnv.memoryMetadataManager + + // import expectSql() shortcut methods + import shark.SharkRunner._ + + override def beforeAll() { + sc.runSql("drop table if exists srcpart_cached") + sc.runSql("create table srcpart_cached(key int, val string) partitioned by (keypart int)") + sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt' + into table srcpart_cached partition (keypart = 1)""") + } + + override def afterAll() { + sc.runSql("drop table if exists srcpart_cached") + } + + test("Hive partition stats are tracked") { + val tableOpt = sharkMetastore.getPartitionedTable(DEFAULT_DATABASE_NAME, "srcpart_cached") + assert(tableOpt.isDefined) + val partitionToStatsOpt = tableOpt.get.getStats("keypart=1") + assert(partitionToStatsOpt.isDefined) + val partitionToStats = partitionToStatsOpt.get + // The 'kv1.txt' file loaded into 'keypart=1' in beforeAll() has 2 partitions. + assert(partitionToStats.size == 2) + } + + test("Hive partition stats are tracked after LOADs and INSERTs") { + // Load more data into srcpart_cached + sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt' + into table srcpart_cached partition (keypart = 1)""") + val tableOpt = sharkMetastore.getPartitionedTable(DEFAULT_DATABASE_NAME, "srcpart_cached") + assert(tableOpt.isDefined) + var partitionToStatsOpt = tableOpt.get.getStats("keypart=1") + assert(partitionToStatsOpt.isDefined) + var partitionToStats = partitionToStatsOpt.get + // The 'kv1.txt' file loaded into 'keypart=1' has 2 partitions. We've loaded it twice at this + // point. + assert(partitionToStats.size == 4) + + // Append using INSERT command + sc.runSql("insert into table srcpart_cached partition(keypart = 1) select * from test") + partitionToStatsOpt = tableOpt.get.getStats("keypart=1") + assert(partitionToStatsOpt.isDefined) + partitionToStats = partitionToStatsOpt.get + assert(partitionToStats.size == 6) + + // INSERT OVERWRITE should overrwritie old table stats. This also restores srcpart_cached + // to contents contained before this test. 
+ sc.runSql("""insert overwrite table srcpart_cached partition(keypart = 1) + select * from test""") + partitionToStatsOpt = tableOpt.get.getStats("keypart=1") + assert(partitionToStatsOpt.isDefined) + partitionToStats = partitionToStatsOpt.get + assert(partitionToStats.size == 2) + } + + ////////////////////////////////////////////////////////////////////////////// + // End-to-end sanity checks + ////////////////////////////////////////////////////////////////////////////// + test("column pruning filters") { + expectSql("select count(*) from test_cached where key > -1", "500") + } + + test("column pruning group by") { + expectSql("select key, count(*) from test_cached group by key order by key limit 1", "0\t3") + } + + test("column pruning group by with single filter") { + expectSql("select key, count(*) from test_cached where val='val_484' group by key", "484\t1") + } + + test("column pruning aggregate function") { + expectSql("select val, sum(key) from test_cached group by val order by val desc limit 1", + "val_98\t196") + } + + test("column pruning filters for a Hive partition") { + expectSql("select count(*) from srcpart_cached where key > -1", "500") + expectSql("select count(*) from srcpart_cached where key > -1 and keypart = 1", "500") + } + + test("column pruning group by for a Hive partition") { + expectSql("select key, count(*) from srcpart_cached group by key order by key limit 1", "0\t3") + } + + test("column pruning group by with single filter for a Hive partition") { + expectSql("select key, count(*) from srcpart_cached where val='val_484' group by key", "484\t1") + } + + test("column pruning aggregate function for a Hive partition") { + expectSql("select val, sum(key) from srcpart_cached group by val order by val desc limit 1", + "val_98\t196") + } + +} diff --git a/src/test/scala/shark/SQLSuite.scala b/src/test/scala/shark/SQLSuite.scala index 9751bcb3..746e3c18 100644 --- a/src/test/scala/shark/SQLSuite.scala +++ b/src/test/scala/shark/SQLSuite.scala @@ -17,87 +17,91 @@ package shark -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import shark.api.QueryExecutionException - - -class SQLSuite extends FunSuite with BeforeAndAfterAll { - - val WAREHOUSE_PATH = TestUtils.getWarehousePath() - val METASTORE_PATH = TestUtils.getMetastorePath() - val MASTER = "local" +import scala.collection.JavaConversions._ - var sc: SharkContext = _ - - override def beforeAll() { - sc = SharkEnv.initWithSharkContext("shark-sql-suite-testing", MASTER) +import org.scalatest.FunSuite - sc.runSql("set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=" + - METASTORE_PATH + ";create=true") - sc.runSql("set hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) +import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.UnionRDD +import org.apache.spark.storage.StorageLevel - sc.runSql("set shark.test.data.path=" + TestUtils.dataFilePath) +import shark.api.QueryExecutionException +import shark.memstore2.{CacheType, MemoryMetadataManager, PartitionedMemoryTable} +import shark.tgf.{RDDSchema, Schema} +// import expectSql() shortcut methods +import shark.SharkRunner._ - // test - sc.runSql("drop table if exists test") - sc.runSql("CREATE TABLE test (key INT, val STRING)") - sc.runSql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv1.txt' INTO TABLE test") - sc.runSql("drop table if exists test_cached") - sc.runSql("CREATE TABLE test_cached 
AS SELECT * FROM test") - // test_null - sc.runSql("drop table if exists test_null") - sc.runSql("CREATE TABLE test_null (key INT, val STRING)") - sc.runSql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv3.txt' INTO TABLE test_null") - sc.runSql("drop table if exists test_null_cached") - sc.runSql("CREATE TABLE test_null_cached AS SELECT * FROM test_null") +class SQLSuite extends FunSuite { - // clicks - sc.runSql("drop table if exists clicks") - sc.runSql("""create table clicks (id int, click int) - row format delimited fields terminated by '\t'""") - sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/clicks.txt' - OVERWRITE INTO TABLE clicks""") - sc.runSql("drop table if exists clicks_cached") - sc.runSql("create table clicks_cached as select * from clicks") + val DEFAULT_DB_NAME = DEFAULT_DATABASE_NAME + val KV1_TXT_PATH = "${hiveconf:shark.test.data.path}/kv1.txt" - // users - sc.runSql("drop table if exists users") - sc.runSql("""create table users (id int, name string) - row format delimited fields terminated by '\t'""") - sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/users.txt' - OVERWRITE INTO TABLE users""") - sc.runSql("drop table if exists users_cached") - sc.runSql("create table users_cached as select * from users") + var sc: SharkContext = SharkRunner.init() + var sharkMetastore: MemoryMetadataManager = SharkEnv.memoryMetadataManager - // test1 - sc.sql("drop table if exists test1") - sc.sql("""CREATE TABLE test1 (id INT, test1val ARRAY) - row format delimited fields terminated by '\t'""") - sc.sql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/test1.txt' INTO TABLE test1") - sc.sql("drop table if exists test1_cached") - sc.sql("CREATE TABLE test1_cached AS SELECT * FROM test1") + private def createCachedPartitionedTable( + tableName: String, + numPartitionsToCreate: Int, + maxCacheSize: Int = 10, + cachePolicyClassName: String = "shark.memstore2.LRUCachePolicy" + ): PartitionedMemoryTable = { + sc.runSql("drop table if exists %s".format(tableName)) + sc.runSql(""" + create table %s(key int, value string) + partitioned by (keypart int) + tblproperties('shark.cache' = 'true', + 'shark.cache.policy.maxSize' = '%d', + 'shark.cache.policy' = '%s') + """.format( + tableName, + maxCacheSize, + cachePolicyClassName)) + var partitionNum = 1 + while (partitionNum <= numPartitionsToCreate) { + sc.runSql("""insert into table %s partition(keypart = %d) + select * from test_cached""".format(tableName, partitionNum)) + partitionNum += 1 + } + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable( + DEFAULT_DB_NAME, tableName).get + partitionedTable } - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") + def isFlattenedUnionRDD(unionRDD: UnionRDD[_]) = { + unionRDD.rdds.find(_.isInstanceOf[UnionRDD[_]]).isEmpty } - private def expectSql(sql: String, expectedResults: Array[String], sort: Boolean = true) { - val sharkResults: Array[String] = sc.runSql(sql).results.map(_.mkString("\t")).toArray - val results = if (sort) sharkResults.sortWith(_ < _) else sharkResults - val expected = if (sort) expectedResults.sortWith(_ < _) else expectedResults - assert(results.corresponds(expected)(_.equals(_)), - "In SQL: " + sql + "\n" + - "Expected: " + expected.mkString("\n") + "; got " + results.mkString("\n")) - } + // Takes a sum over the table's 'key' column, for both the cached contents and the copy on 
disk. + def expectUnifiedKVTable( + cachedTableName: String, + partSpecOpt: Option[Map[String, String]] = None) { + // Check that the table is in memory and is a unified view. + val sharkTableOpt = sharkMetastore.getTable(DEFAULT_DB_NAME, cachedTableName) + assert(sharkTableOpt.isDefined, "Table %s cannot be found in the Shark metastore") + assert(sharkTableOpt.get.cacheMode == CacheType.MEMORY, + "'shark.cache' field for table %s is not CacheType.MEMORY") - // A shortcut for single row results. - private def expectSql(sql: String, expectedResult: String) { - expectSql(sql, Array(expectedResult)) + // Load a non-cached copy of the table into memory. + val cacheSum = sc.sql("select sum(key) from %s".format(cachedTableName))(0) + val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, cachedTableName) + val location = partSpecOpt match { + case Some(partSpec) => { + val partition = Hive.get().getPartition(hiveTable, partSpec, false /* forceCreate */) + partition.getDataLocation.toString + } + case None => hiveTable.getDataLocation.toString + } + // Create a table with contents loaded from the table's data directory. + val diskTableName = "%s_disk_copy".format(cachedTableName) + sc.sql("drop table if exists %s".format(diskTableName)) + sc.sql("create table %s (key int, value string)".format(diskTableName)) + sc.sql("load data local inpath '%s' into table %s".format(location, diskTableName)) + val diskSum = sc.sql("select sum(key) from %s".format(diskTableName))(0) + assert(diskSum == cacheSum, "Sum of keys from cached and disk contents differ") } ////////////////////////////////////////////////////////////////////////////// @@ -166,26 +170,6 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sort = false) } - ////////////////////////////////////////////////////////////////////////////// - // column pruning - ////////////////////////////////////////////////////////////////////////////// - test("column pruning filters") { - expectSql("select count(*) from test_cached where key > -1", "500") - } - - test("column pruning group by") { - expectSql("select key, count(*) from test_cached group by key order by key limit 1", "0\t3") - } - - test("column pruning group by with single filter") { - expectSql("select key, count(*) from test_cached where val='val_484' group by key", "484\t1") - } - - test("column pruning aggregate function") { - expectSql("select val, sum(key) from test_cached group by val order by val desc limit 1", - "val_98\t196") - } - ////////////////////////////////////////////////////////////////////////////// // join ////////////////////////////////////////////////////////////////////////////// @@ -221,6 +205,42 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { ////////////////////////////////////////////////////////////////////////////// // cache DDL ////////////////////////////////////////////////////////////////////////////// + test("Use regular CREATE TABLE and '_cached' suffix to create cached table") { + sc.runSql("drop table if exists empty_table_cached") + sc.runSql("create table empty_table_cached(key string, value string)") + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "empty_table_cached")) + assert(!SharkEnv.memoryMetadataManager.isHivePartitioned(DEFAULT_DB_NAME, "empty_table_cached")) + } + + test("Use regular CREATE TABLE and table properties to create cached table") { + sc.runSql("drop table if exists empty_table_cached_tbl_props") + sc.runSql("""create table empty_table_cached_tbl_props(key string, value string) + 
TBLPROPERTIES('shark.cache' = 'true')""") + assert(SharkEnv.memoryMetadataManager.containsTable( + DEFAULT_DB_NAME, "empty_table_cached_tbl_props")) + assert(!SharkEnv.memoryMetadataManager.isHivePartitioned( + DEFAULT_DB_NAME, "empty_table_cached_tbl_props")) + } + + test("Insert into empty cached table") { + sc.runSql("drop table if exists new_table_cached") + sc.runSql("create table new_table_cached(key string, value string)") + sc.runSql("insert into table new_table_cached select * from test where key > -1 limit 499") + expectSql("select count(*) from new_table_cached", "499") + } + + test("rename cached table") { + sc.runSql("drop table if exists test_oldname_cached") + sc.runSql("drop table if exists test_rename") + sc.runSql("create table test_oldname_cached as select * from test") + sc.runSql("alter table test_oldname_cached rename to test_rename") + + assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "test_oldname_cached")) + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "test_rename")) + + expectSql("select count(*) from test_rename", "500") + } + test("insert into cached tables") { sc.runSql("drop table if exists test1_cached") sc.runSql("create table test1_cached as select * from test") @@ -249,22 +269,24 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { } } - ignore("drop partition") { - sc.runSql("create table foo_cached(key int, val string) partitioned by (dt string)") - sc.runSql("insert overwrite table foo_cached partition(dt='100') select * from test") - expectSql("select count(*) from foo_cached", "500") - sc.runSql("alter table foo_cached drop partition(dt='100')") - expectSql("select count(*) from foo_cached", "0") - } - - test("create cached table with table properties") { + test("create cached table with 'shark.cache' flag in table properties") { sc.runSql("drop table if exists ctas_tbl_props") sc.runSql("""create table ctas_tbl_props TBLPROPERTIES ('shark.cache'='true') as select * from test""") - assert(SharkEnv.memoryMetadataManager.contains("ctas_tbl_props")) + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "ctas_tbl_props")) expectSql("select * from ctas_tbl_props where key=407", "407\tval_407") } + test("default to Hive table creation when 'shark.cache' flag is false in table properties") { + sc.runSql("drop table if exists ctas_tbl_props_should_not_be_cached") + sc.runSql(""" + CREATE TABLE ctas_tbl_props_result_should_not_be_cached + TBLPROPERTIES ('shark.cache'='false') + AS select * from test""") + assert(!SharkEnv.memoryMetadataManager.containsTable( + DEFAULT_DB_NAME, "ctas_tbl_props_should_not_be_cached")) + } + test("cached tables with complex types") { sc.runSql("drop table if exists test_complex_types") sc.runSql("drop table if exists test_complex_types_cached") @@ -286,7 +308,8 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { assert(sc.sql("select d from test_complex_types_cached where a = 'a0'").head === """{"d01":["d011","d012"],"d02":["d021","d022"]}""") - assert(SharkEnv.memoryMetadataManager.contains("test_complex_types_cached")) + assert(SharkEnv.memoryMetadataManager.containsTable( + DEFAULT_DB_NAME, "test_complex_types_cached")) } test("disable caching by default") { @@ -294,7 +317,8 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("drop table if exists should_not_be_cached") sc.runSql("create table should_not_be_cached as select * from test") expectSql("select key from should_not_be_cached where key = 407", "407") - 
assert(!SharkEnv.memoryMetadataManager.contains("should_not_be_cached")) + assert(!SharkEnv.memoryMetadataManager.containsTable( + DEFAULT_DB_NAME, "should_not_be_cached")) sc.runSql("set shark.cache.flag.checkTableName=true") } @@ -303,7 +327,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("""create table sharkTest5Cached TBLPROPERTIES ("shark.cache" = "true") as select * from test""") expectSql("select val from sharktest5Cached where key = 407", "val_407") - assert(SharkEnv.memoryMetadataManager.contains("sharkTest5Cached")) + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "sharkTest5Cached")) } test("dropping cached tables should clean up RDDs") { @@ -311,7 +335,325 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("""create table sharkTest5Cached TBLPROPERTIES ("shark.cache" = "true") as select * from test""") sc.runSql("drop table sharkTest5Cached") - assert(!SharkEnv.memoryMetadataManager.contains("sharkTest5Cached")) + assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "sharkTest5Cached")) + } + + ////////////////////////////////////////////////////////////////////////////// + // Caching Hive-partititioned tables + // Note: references to 'partition' for this section refer to a Hive-partition. + ////////////////////////////////////////////////////////////////////////////// + test("Use regular CREATE TABLE and '_cached' suffix to create cached, partitioned table") { + sc.runSql("drop table if exists empty_part_table_cached") + sc.runSql("""create table empty_part_table_cached(key int, value string) + partitioned by (keypart int)""") + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "empty_part_table_cached")) + assert(SharkEnv.memoryMetadataManager.isHivePartitioned( + DEFAULT_DB_NAME, "empty_part_table_cached")) + } + + test("Use regular CREATE TABLE and table properties to create cached, partitioned table") { + sc.runSql("drop table if exists empty_part_table_cached_tbl_props") + sc.runSql("""create table empty_part_table_cached_tbl_props(key int, value string) + partitioned by (keypart int) tblproperties('shark.cache' = 'true')""") + assert(SharkEnv.memoryMetadataManager.containsTable( + DEFAULT_DB_NAME, "empty_part_table_cached_tbl_props")) + assert(SharkEnv.memoryMetadataManager.isHivePartitioned( + DEFAULT_DB_NAME, "empty_part_table_cached_tbl_props")) + } + + test("alter cached table by adding a new partition") { + sc.runSql("drop table if exists alter_part_cached") + sc.runSql("""create table alter_part_cached(key int, value string) + partitioned by (keypart int)""") + sc.runSql("""alter table alter_part_cached add partition(keypart = 1)""") + val tableName = "alter_part_cached" + val partitionColumn = "keypart=1" + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable( + DEFAULT_DB_NAME, tableName).get + assert(partitionedTable.containsPartition(partitionColumn)) + } + + test("alter cached table by dropping a partition") { + sc.runSql("drop table if exists alter_drop_part_cached") + sc.runSql("""create table alter_drop_part_cached(key int, value string) + partitioned by (keypart int)""") + sc.runSql("""alter table alter_drop_part_cached add partition(keypart = 1)""") + val tableName = "alter_drop_part_cached" + val partitionColumn = "keypart=1" + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + val partitionedTable = 
SharkEnv.memoryMetadataManager.getPartitionedTable( + DEFAULT_DB_NAME, tableName).get + assert(partitionedTable.containsPartition(partitionColumn)) + sc.runSql("""alter table alter_drop_part_cached drop partition(keypart = 1)""") + assert(!partitionedTable.containsPartition(partitionColumn)) + } + + test("insert into a partition of a cached table") { + val tableName = "insert_part_cached" + val partitionedTable = createCachedPartitionedTable( + tableName, + 1 /* numPartitionsToCreate */) + expectSql("select value from insert_part_cached where key = 407 and keypart = 1", "val_407") + + } + + test("insert overwrite a partition of a cached table") { + val tableName = "insert_over_part_cached" + val partitionedTable = createCachedPartitionedTable( + tableName, + 1 /* numPartitionsToCreate */) + expectSql("""select value from insert_over_part_cached + where key = 407 and keypart = 1""", "val_407") + sc.runSql("""insert overwrite table insert_over_part_cached partition(keypart = 1) + select key, -1 from test""") + expectSql("select value from insert_over_part_cached where key = 407 and keypart = 1", "-1") + } + + test("scan cached, partitioned table that's empty") { + sc.runSql("drop table if exists empty_part_table_cached") + sc.runSql("""create table empty_part_table_cached(key int, value string) + partitioned by (keypart int)""") + expectSql("select count(*) from empty_part_table_cached", "0") + } + + test("scan cached, partitioned table that has a single partition") { + val tableName = "scan_single_part_cached" + val partitionedTable = createCachedPartitionedTable( + tableName, + 1 /* numPartitionsToCreate */) + expectSql("select * from scan_single_part_cached where key = 407", "407\tval_407\t1") + } + + test("scan cached, partitioned table that has multiple partitions") { + val tableName = "scan_mult_part_cached" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */) + expectSql("select * from scan_mult_part_cached where key = 407 order by keypart", + Array("407\tval_407\t1", "407\tval_407\t2", "407\tval_407\t3")) + } + + test("drop/unpersist cached, partitioned table that has multiple partitions") { + val tableName = "drop_mult_part_cached" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */) + val keypart1RDD = partitionedTable.getPartition("keypart=1") + val keypart2RDD = partitionedTable.getPartition("keypart=2") + val keypart3RDD = partitionedTable.getPartition("keypart=3") + sc.runSql("drop table drop_mult_part_cached ") + assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + // All RDDs should have been unpersisted. 
+ assert(keypart1RDD.get.getStorageLevel == StorageLevel.NONE) + assert(keypart2RDD.get.getStorageLevel == StorageLevel.NONE) + assert(keypart3RDD.get.getStorageLevel == StorageLevel.NONE) + } + + test("drop cached partition represented by a UnionRDD (i.e., the result of multiple inserts)") { + val tableName = "drop_union_part_cached" + val partitionedTable = createCachedPartitionedTable( + tableName, + 1 /* numPartitionsToCreate */) + sc.runSql("insert into table drop_union_part_cached partition(keypart = 1) select * from test") + sc.runSql("insert into table drop_union_part_cached partition(keypart = 1) select * from test") + sc.runSql("insert into table drop_union_part_cached partition(keypart = 1) select * from test") + val keypart1RDD = partitionedTable.getPartition("keypart=1") + sc.runSql("drop table drop_union_part_cached") + assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + // All RDDs should have been unpersisted. + assert(keypart1RDD.get.getStorageLevel == StorageLevel.NONE) + } + + ////////////////////////////////////////////////////////////////////////////// + // RDD(partition) eviction policy for cached Hive-partititioned tables + ////////////////////////////////////////////////////////////////////////////// + + test("shark.memstore2.CacheAllPolicy is the default policy") { + val tableName = "default_policy_cached" + sc.runSql("""create table default_policy_cached(key int, value string) + partitioned by (keypart int)""") + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable( + DEFAULT_DB_NAME, tableName).get + val cachePolicy = partitionedTable.cachePolicy + assert(cachePolicy.isInstanceOf[shark.memstore2.CacheAllPolicy[_, _]]) + } + + test("LRU: RDDs are not evicted if the cache isn't full.") { + val tableName = "evict_partitions_maxSize" + val partitionedTable = createCachedPartitionedTable( + tableName, + 2 /* numPartitionsToCreate */, + 3 /* maxCacheSize */, + "shark.memstore2.LRUCachePolicy") + val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + } + + test("LRU: RDDs are evicted when the max size is reached.") { + val tableName = "evict_partitions_maxSize" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */, + 3 /* maxCacheSize */, + "shark.memstore2.LRUCachePolicy") + val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + sc.runSql("""insert into table evict_partitions_maxSize partition(keypart = 4) + select * from test""") + assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.NONE) + } + + test("LRU: RDD eviction accounts for partition scans - a cache.get()") { + val tableName = "evict_partitions_with_get" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */, + 3 /* maxCacheSize */, + "shark.memstore2.LRUCachePolicy") + val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2") + assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.MEMORY_AND_DISK) + sc.runSql("select count(1) from evict_partitions_with_get where keypart = 1") + 
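+ // The scan above counts as a cache.get() for keypart=1 and moves it to the most-recently-used position, so the insert below should evict keypart=2 instead.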
sc.runSql("""insert into table evict_partitions_with_get partition(keypart = 4) + select * from test""") + assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + + assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.NONE) + } + + test("LRU: RDD eviction accounts for INSERT INTO - a cache.get().") { + val tableName = "evict_partitions_insert_into" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */, + 3 /* maxCacheSize */, + "shark.memstore2.LRUCachePolicy") + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + val oldKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2") + assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.MEMORY_AND_DISK) + sc.runSql("""insert into table evict_partitions_insert_into partition(keypart = 1) + select * from test""") + sc.runSql("""insert into table evict_partitions_insert_into partition(keypart = 4) + select * from test""") + assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + val newKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + assert(TestUtils.getStorageLevelOfRDD(newKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + + val keypart2StorageLevel = TestUtils.getStorageLevelOfRDD(keypart2RDD.get) + assert(keypart2StorageLevel == StorageLevel.NONE) + } + + test("LRU: RDD eviction accounts for INSERT OVERWRITE - a cache.put()") { + val tableName = "evict_partitions_insert_overwrite" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */, + 3 /* maxCacheSize */, + "shark.memstore2.LRUCachePolicy") + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + val oldKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2") + assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.MEMORY_AND_DISK) + sc.runSql("""insert overwrite table evict_partitions_insert_overwrite partition(keypart = 1) + select * from test""") + sc.runSql("""insert into table evict_partitions_insert_overwrite partition(keypart = 4) + select * from test""") + assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.NONE) + val newKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + assert(TestUtils.getStorageLevelOfRDD(newKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK) + + val keypart2StorageLevel = TestUtils.getStorageLevelOfRDD(keypart2RDD.get) + assert(keypart2StorageLevel == StorageLevel.NONE) + } + + test("LRU: RDD eviction accounts for ALTER TABLE DROP PARTITION - a cache.remove()") { + val tableName = "evict_partitions_removals" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */, + 3 /* maxCacheSize */, + "shark.memstore2.LRUCachePolicy") + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + sc.runSql("alter table evict_partitions_removals drop partition(keypart = 1)") + sc.runSql("""insert into table evict_partitions_removals partition(keypart = 4) + select * from test""") + sc.runSql("""insert into table evict_partitions_removals partition(keypart = 
5) + select * from test""") + val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2") + assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.NONE) + } + + test("LRU: get() reloads an RDD previously unpersist()'d.") { + val tableName = "reload_evicted_partition" + val partitionedTable = createCachedPartitionedTable( + tableName, + 3 /* numPartitionsToCreate */, + 3 /* maxCacheSize */, + "shark.memstore2.LRUCachePolicy") + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName)) + val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1") + val lvl = TestUtils.getStorageLevelOfRDD(keypart1RDD.get) + assert(lvl == StorageLevel.MEMORY_AND_DISK, "got: " + lvl) + sc.runSql("""insert into table reload_evicted_partition partition(keypart = 4) + select * from test""") + assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.NONE) + + // Scanning partition (keypart = 1) should reload the corresponding RDD into the cache, and + // cause eviction of the RDD for partition (keypart = 2). + sc.runSql("select count(1) from reload_evicted_partition where keypart = 1") + assert(keypart1RDD.get.getStorageLevel == StorageLevel.MEMORY_AND_DISK) + val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2") + val keypart2StorageLevel = TestUtils.getStorageLevelOfRDD(keypart2RDD.get) + assert(keypart2StorageLevel == StorageLevel.NONE, + "StorageLevel for partition(keypart=2) should be NONE, but got: " + keypart2StorageLevel) + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Prevent nested UnionRDDs - those should be "flattened" in MemoryStoreSinkOperator. + /////////////////////////////////////////////////////////////////////////////////////// + + test("flatten UnionRDDs") { + sc.sql("create table flat_cached as select * from test_cached") + sc.sql("insert into table flat_cached select * from test") + val tableName = "flat_cached" + var memoryTable = SharkEnv.memoryMetadataManager.getMemoryTable(DEFAULT_DB_NAME, tableName).get + var unionRDD = memoryTable.getRDD.get.asInstanceOf[UnionRDD[_]] + val numParentRDDs = unionRDD.rdds.size + assert(isFlattenedUnionRDD(unionRDD)) + + // Insert another set of query results. The flattening should kick in here. + sc.sql("insert into table flat_cached select * from test") + unionRDD = memoryTable.getRDD.get.asInstanceOf[UnionRDD[_]] + assert(isFlattenedUnionRDD(unionRDD)) + assert(unionRDD.rdds.size == numParentRDDs + 1) + } + + test("flatten UnionRDDs for partitioned tables") { + sc.sql("drop table if exists part_table_cached") + sc.sql("""create table part_table_cached(key int, value string) + partitioned by (keypart int)""") + sc.sql("alter table part_table_cached add partition(keypart = 1)") + sc.sql("insert into table part_table_cached partition(keypart = 1) select * from flat_cached") + val tableName = "part_table_cached" + val partitionKey = "keypart=1" + var partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable( + DEFAULT_DB_NAME, tableName).get + var unionRDD = partitionedTable.keyToPartitions.get(partitionKey).get.asInstanceOf[UnionRDD[_]] + val numParentRDDs = unionRDD.rdds.size + assert(isFlattenedUnionRDD(unionRDD)) + + // Insert another set of query results into the same partition. + // The flattening should kick in here. 
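+ // (The partition's RDD should stay a single-level UnionRDD that gains one more parent, rather than becoming a UnionRDD nested inside another UnionRDD.)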
+ sc.runSql("insert into table part_table_cached partition(keypart = 1) select * from flat_cached") + unionRDD = partitionedTable.getPartition(partitionKey).get.asInstanceOf[UnionRDD[_]] + assert(isFlattenedUnionRDD(unionRDD)) + assert(unionRDD.rdds.size == numParentRDDs + 1) } ////////////////////////////////////////////////////////////////////////////// @@ -322,11 +664,11 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.sql("drop table if exists adw") sc.sql("""create table adw TBLPROPERTIES ("shark.cache" = "true") as select cast(key as int) as k, val from test""") - expectSql("select count(k) from adw where val='val_487' group by 1 having count(1) > 0","1") + expectSql("select count(k) from adw where val='val_487' group by 1 having count(1) > 0", "1") } ////////////////////////////////////////////////////////////////////////////// - // Sel Star + // Partition pruning ////////////////////////////////////////////////////////////////////////////// test("sel star pruning") { @@ -336,11 +678,45 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { expectSql("select * from selstar where val='val_487'","487 val_487") } + test("map pruning with functions in between clause") { + sc.sql("drop table if exists mapsplitfunc") + sc.sql("drop table if exists mapsplitfunc_cached") + sc.sql("create table mapsplitfunc(k bigint, v string)") + sc.sql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt' + OVERWRITE INTO TABLE mapsplitfunc""") + sc.sql("create table mapsplitfunc_cached as select * from mapsplitfunc") + expectSql("""select count(*) from mapsplitfunc_cached + where month(from_unixtime(k)) between "1" and "12" """, Array[String]("500")) + expectSql("""select count(*) from mapsplitfunc_cached + where year(from_unixtime(k)) between "2013" and "2014" """, Array[String]("0")) + } + + ////////////////////////////////////////////////////////////////////////////// + // SharkContext APIs (e.g. 
sql2rdd, sql) + ////////////////////////////////////////////////////////////////////////////// + + test("cached table in different new database") { + sc.sql("drop table if exists selstar") + sc.sql("""create table selstar TBLPROPERTIES ("shark.cache" = "true") as + select * from default.test """) + sc.sql("use seconddb") + sc.sql("drop table if exists selstar") + sc.sql("""create table selstar TBLPROPERTIES ("shark.cache" = "true") as + select * from default.test where key != 'val_487' """) + + sc.sql("use default") + expectSql("select * from selstar where val='val_487'","487 val_487") + + assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "selstar")) + assert(SharkEnv.memoryMetadataManager.containsTable("seconddb", "selstar")) + + } + ////////////////////////////////////////////////////////////////////////////// // various data types ////////////////////////////////////////////////////////////////////////////// - test("various data types") { + test("boolean data type") { sc.sql("drop table if exists checkboolean") sc.sql("""create table checkboolean TBLPROPERTIES ("shark.cache" = "true") as select key, val, true as flag from test where key < "300" """) @@ -348,7 +724,9 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { select key, val, false as flag from test where key > "300" """) expectSql("select flag, count(*) from checkboolean group by flag order by flag asc", Array[String]("false\t208", "true\t292")) + } + test("byte data type") { sc.sql("drop table if exists checkbyte") sc.sql("drop table if exists checkbyte_cached") sc.sql("""create table checkbyte (key string, val string, flag tinyint) """) @@ -359,7 +737,10 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.sql("""create table checkbyte_cached as select * from checkbyte""") expectSql("select flag, count(*) from checkbyte_cached group by flag order by flag asc", Array[String]("0\t208", "1\t292")) + } + test("binary data type") { + sc.sql("drop table if exists checkbinary") sc.sql("drop table if exists checkbinary_cached") sc.sql("""create table checkbinary (key string, flag binary) """) @@ -370,7 +751,9 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.sql("create table checkbinary_cached as select key, flag from checkbinary") expectSql("select cast(flag as string) as f from checkbinary_cached order by f asc limit 2", Array[String]("val_0", "val_0")) + } + test("short data type") { sc.sql("drop table if exists checkshort") sc.sql("drop table if exists checkshort_cached") sc.sql("""create table checkshort (key string, val string, flag smallint) """) @@ -419,4 +802,288 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { val e = intercept[QueryExecutionException] { sc.sql2rdd("asdfasdfasdfasdf") } e.getMessage.contains("semantic") } + + ////////////////////////////////////////////////////////////////////////////// + // Default cache mode is CacheType.MEMORY (unified view) + ////////////////////////////////////////////////////////////////////////////// + test ("Table created by CREATE TABLE, with table properties, is CacheType.MEMORY by default") { + sc.runSql("drop table if exists test_unify_creation") + sc.runSql("""create table test_unify_creation (key int, val string) + tblproperties('shark.cache'='true')""") + val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_creation").get + assert(table.cacheMode == CacheType.MEMORY) + sc.runSql("drop table if exists test_unify_creation") + } + + test ("Table created by CREATE TABLE, with '_cached', is CacheType.MEMORY 
by default") { + sc.runSql("drop table if exists test_unify_creation_cached") + sc.runSql("create table test_unify_creation_cached(key int, val string)") + val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_creation_cached").get + assert(table.cacheMode == CacheType.MEMORY) + sc.runSql("drop table if exists test_unify_creation_cached") + } + + test ("Table created by CTAS, with table properties, is CacheType.MEMORY by default") { + sc.runSql("drop table if exists test_unify_ctas") + sc.runSql("""create table test_unify_ctas tblproperties('shark.cache' = 'true') + as select * from test""") + val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_ctas").get + assert(table.cacheMode == CacheType.MEMORY) + expectSql("select count(*) from test_unify_ctas", "500") + sc.runSql("drop table if exists test_unify_ctas") + } + + test ("Table created by CTAS, with '_cached', is CacheType.MEMORY by default") { + sc.runSql("drop table if exists test_unify_ctas_cached") + sc.runSql("create table test_unify_ctas_cached as select * from test") + val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_ctas_cached").get + assert(table.cacheMode == CacheType.MEMORY) + expectSql("select count(*) from test_unify_ctas_cached", "500") + sc.runSql("drop table if exists test_unify_ctas_cached") + } + + test ("CREATE TABLE when 'shark.cache' is CacheType.MEMORY_ONLY") { + sc.runSql("drop table if exists test_non_unify_creation") + sc.runSql("""create table test_non_unify_creation(key int, val string) + tblproperties('shark.cache' = 'memory_only')""") + val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_non_unify_creation").get + assert(table.cacheMode == CacheType.MEMORY_ONLY) + sc.runSql("drop table if exists test_non_unify_creation") + } + + test ("CTAS when 'shark.cache' is CacheType.MEMORY_ONLY") { + sc.runSql("drop table if exists test_non_unify_ctas") + sc.runSql("""create table test_non_unify_ctas tblproperties + ('shark.cache' = 'memory_only') as select * from test""") + val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_non_unify_ctas").get + assert(table.cacheMode == CacheType.MEMORY_ONLY) + sc.runSql("drop table if exists test_non_unify_ctas") + } + + ////////////////////////////////////////////////////////////////////////////// + // LOAD for tables cached in memory and stored on disk (unified view) + ////////////////////////////////////////////////////////////////////////////// + test ("LOAD INTO unified view") { + sc.runSql("drop table if exists unified_view_cached") + sc.runSql("create table unified_view_cached (key int, value string)") + sc.runSql("load data local inpath '%s' into table unified_view_cached".format(KV1_TXT_PATH)) + expectUnifiedKVTable("unified_view_cached") + expectSql("select count(*) from unified_view_cached", "500") + sc.runSql("drop table if exists unified_view_cached") + } + + test ("LOAD OVERWRITE unified view") { + sc.runSql("drop table if exists unified_overwrite_cached") + sc.runSql("create table unified_overwrite_cached (key int, value string)") + sc.runSql("load data local inpath '%s' into table unified_overwrite_cached". + format("${hiveconf:shark.test.data.path}/kv3.txt")) + expectSql("select count(*) from unified_overwrite_cached", "25") + sc.runSql("load data local inpath '%s' overwrite into table unified_overwrite_cached". + format(KV1_TXT_PATH)) + // Make sure the cached contents matches the disk contents. 
+ expectUnifiedKVTable("unified_overwrite_cached") + expectSql("select count(*) from unified_overwrite_cached", "500") + sc.runSql("drop table if exists unified_overwrite_cached") + } + + test ("LOAD INTO partitioned unified view") { + sc.runSql("drop table if exists unified_view_part_cached") + sc.runSql("""create table unified_view_part_cached (key int, value string) + partitioned by (keypart int)""") + sc.runSql("""load data local inpath '%s' into table unified_view_part_cached + partition(keypart = 1)""".format(KV1_TXT_PATH)) + expectUnifiedKVTable("unified_view_part_cached", Some(Map("keypart" -> "1"))) + expectSql("select count(*) from unified_view_part_cached", "500") + sc.runSql("drop table if exists unified_view_part_cached") + } + + test ("LOAD OVERWRITE partitioned unified view") { + sc.runSql("drop table if exists unified_overwrite_part_cached") + sc.runSql("""create table unified_overwrite_part_cached (key int, value string) + partitioned by (keypart int)""") + sc.runSql("""load data local inpath '%s' overwrite into table unified_overwrite_part_cached + partition(keypart = 1)""".format(KV1_TXT_PATH)) + expectUnifiedKVTable("unified_overwrite_part_cached", Some(Map("keypart" -> "1"))) + expectSql("select count(*) from unified_overwrite_part_cached", "500") + sc.runSql("drop table if exists unified_overwrite_part_cached") + } + + ////////////////////////////////////////////////////////////////////////////// + // INSERT for tables cached in memory and stored on disk (unified view) + ////////////////////////////////////////////////////////////////////////////// + test ("INSERT INTO unified view") { + sc.runSql("drop table if exists unified_view_cached") + sc.runSql("create table unified_view_cached as select * from test_cached") + sc.runSql("insert into table unified_view_cached select * from test_cached") + expectUnifiedKVTable("unified_view_cached") + expectSql("select count(*) from unified_view_cached", "1000") + sc.runSql("drop table if exists unified_view_cached") + } + + test ("INSERT OVERWRITE unified view") { + sc.runSql("drop table if exists unified_overwrite_cached") + sc.runSql("create table unified_overwrite_cached as select * from test") + sc.runSql("insert overwrite table unified_overwrite_cached select * from test_cached") + expectUnifiedKVTable("unified_overwrite_cached") + expectSql("select count(*) from unified_overwrite_cached", "500") + sc.runSql("drop table if exists unified_overwrite_cached") + } + + test ("INSERT INTO partitioned unified view") { + sc.runSql("drop table if exists unified_view_part_cached") + sc.runSql("""create table unified_view_part_cached (key int, value string) + partitioned by (keypart int)""") + sc.runSql("""insert into table unified_view_part_cached partition (keypart = 1) + select * from test_cached""") + expectUnifiedKVTable("unified_view_part_cached", Some(Map("keypart" -> "1"))) + expectSql("select count(*) from unified_view_part_cached where keypart = 1", "500") + sc.runSql("drop table if exists unified_view_part_cached") + } + + test ("INSERT OVERWRITE partitioned unified view") { + sc.runSql("drop table if exists unified_overwrite_part_cached") + sc.runSql("""create table unified_overwrite_part_cached (key int, value string) + partitioned by (keypart int)""") + sc.runSql("""insert overwrite table unified_overwrite_part_cached partition (keypart = 1) + select * from test_cached""") + expectUnifiedKVTable("unified_overwrite_part_cached", Some(Map("keypart" -> "1"))) + expectSql("select count(*) from 
unified_overwrite_part_cached", "500") + sc.runSql("drop table if exists unified_overwrite_part_cached") + } + + ////////////////////////////////////////////////////////////////////////////// + // CACHE and ALTER TABLE commands + ////////////////////////////////////////////////////////////////////////////// + test ("ALTER TABLE caches non-partitioned table if 'shark.cache' is set to true") { + sc.runSql("drop table if exists unified_load") + sc.runSql("create table unified_load as select * from test") + sc.runSql("alter table unified_load set tblproperties('shark.cache' = 'true')") + expectUnifiedKVTable("unified_load") + sc.runSql("drop table if exists unified_load") + } + + test ("ALTER TABLE caches partitioned table if 'shark.cache' is set to true") { + sc.runSql("drop table if exists unified_part_load") + sc.runSql("create table unified_part_load (key int, value string) partitioned by (keypart int)") + sc.runSql("insert into table unified_part_load partition (keypart=1) select * from test_cached") + sc.runSql("alter table unified_part_load set tblproperties('shark.cache' = 'true')") + expectUnifiedKVTable("unified_part_load", Some(Map("keypart" -> "1"))) + sc.runSql("drop table if exists unified_part_load") + } + + test ("ALTER TABLE uncaches non-partitioned table if 'shark.cache' is set to false") { + sc.runSql("drop table if exists unified_load") + sc.runSql("create table unified_load as select * from test") + sc.runSql("alter table unified_load set tblproperties('shark.cache' = 'false')") + assert(!sharkMetastore.containsTable(DEFAULT_DB_NAME, "unified_load")) + expectSql("select count(*) from unified_load", "500") + sc.runSql("drop table if exists unified_load") + } + + test ("ALTER TABLE uncaches partitioned table if 'shark.cache' is set to false") { + sc.runSql("drop table if exists unified_part_load") + sc.runSql("create table unified_part_load (key int, value string) partitioned by (keypart int)") + sc.runSql("insert into table unified_part_load partition (keypart=1) select * from test_cached") + sc.runSql("alter table unified_part_load set tblproperties('shark.cache' = 'false')") + assert(!sharkMetastore.containsTable(DEFAULT_DB_NAME, "unified_part_load")) + expectSql("select count(*) from unified_part_load", "500") + sc.runSql("drop table if exists unified_part_load") + } + + test ("UNCACHE behaves like ALTER TABLE SET TBLPROPERTIES ...") { + sc.runSql("drop table if exists unified_load") + sc.runSql("create table unified_load as select * from test") + sc.runSql("cache unified_load") + // Double check the table properties. + val tableName = "unified_load" + val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName) + assert(hiveTable.getProperty("shark.cache") == "MEMORY") + // Check that the cache and disk contents are synchronized. + expectUnifiedKVTable(tableName) + sc.runSql("drop table if exists unified_load") + } + + test ("CACHE behaves like ALTER TABLE SET TBLPROPERTIES ...") { + sc.runSql("drop table if exists unified_load") + sc.runSql("create table unified_load as select * from test") + sc.runSql("cache unified_load") + // Double check the table properties. + val tableName = "unified_load" + val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName) + assert(hiveTable.getProperty("shark.cache") == "MEMORY") + // Check that the cache and disk contents are synchronized. 
+ expectUnifiedKVTable(tableName) + sc.runSql("drop table if exists unified_load") + } + + ////////////////////////////////////////////////////////////////////////////// + // Cached table persistence + ////////////////////////////////////////////////////////////////////////////// + test ("Cached tables persist across Shark metastore shutdowns.") { + val globalCachedTableNames = Seq("test_cached", "test_null_cached", "clicks_cached", + "users_cached", "test1_cached") + + // Number of rows for each cached table. + val cachedTableCounts = new Array[String](globalCachedTableNames.size) + for ((tableName, i) <- globalCachedTableNames.zipWithIndex) { + val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName) + val cachedCount = sc.sql("select count(*) from %s".format(tableName))(0) + cachedTableCounts(i) = cachedCount + } + sharkMetastore.shutdown() + for ((tableName, i) <- globalCachedTableNames.zipWithIndex) { + val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName) + // Check that the number of rows from the table on disk remains the same. + val onDiskCount = sc.sql("select count(*) from %s".format(tableName))(0) + val cachedCount = cachedTableCounts(i) + assert(onDiskCount == cachedCount, """Num rows for %s differ across Shark metastore restart. + (rows cached = %s, rows on disk = %s)""".format(tableName, cachedCount, onDiskCount)) + // Check that we're able to materialize a row - i.e., make sure that table scan operator + // doesn't try to use a ColumnarSerDe when scanning contents on disk (for our test tables, + // LazySimpleSerDes should be used). + sc.sql("select * from %s limit 1".format(tableName)) + } + // Finally, reload all tables. + SharkRunner.loadTables() + } + + ////////////////////////////////////////////////////////////////////////////// + // Table Generating Functions (TGFs) + ////////////////////////////////////////////////////////////////////////////// + + test("Simple TGFs") { + expectSql("generate shark.TestTGF1(test, 15)", Array(15,15,15,17,19).map(_.toString).toArray) + } + + test("Saving simple TGFs") { + sc.sql("drop table if exists TGFTestTable") + sc.runSql("generate shark.TestTGF1(test, 15) as TGFTestTable") + expectSql("select * from TGFTestTable", Array(15,15,15,17,19).map(_.toString).toArray) + sc.sql("drop table if exists TGFTestTable") + } + + test("Advanced TGFs") { + expectSql("generate shark.TestTGF2(test, 25)", Array(25,25,25,27,29).map(_.toString).toArray) + } + + test("Saving advanced TGFs") { + sc.sql("drop table if exists TGFTestTable2") + sc.runSql("generate shark.TestTGF2(test, 25) as TGFTestTable2") + expectSql("select * from TGFTestTable2", Array(25,25,25,27,29).map(_.toString).toArray) + sc.sql("drop table if exists TGFTestTable2") + } +} + +object TestTGF1 { + @Schema(spec = "values int") + def apply(test: RDD[(Int, String)], integer: Int) = { + test.map{ case Tuple2(k, v) => Tuple1(k + integer) }.filter{ case Tuple1(v) => v < 20 } + } +} + +object TestTGF2 { + def apply(sc: SharkContext, test: RDD[(Int, String)], integer: Int) = { + val rdd = test.map{ case Tuple2(k, v) => Seq(k + integer) }.filter{ case Seq(v) => v < 30 } + RDDSchema(rdd.asInstanceOf[RDD[Seq[_]]], "myvalues int") + } } diff --git a/src/test/scala/shark/SharkRunner.scala b/src/test/scala/shark/SharkRunner.scala new file mode 100644 index 00000000..573ecec2 --- /dev/null +++ b/src/test/scala/shark/SharkRunner.scala @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package shark + +import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME + +import shark.api.JavaSharkContext +import shark.memstore2.MemoryMetadataManager + + +object SharkRunner { + + val WAREHOUSE_PATH = TestUtils.getWarehousePath() + val METASTORE_PATH = TestUtils.getMetastorePath() + val MASTER = "local" + + var sc: SharkContext = _ + + var javaSc: JavaSharkContext = _ + + def init(): SharkContext = synchronized { + if (sc == null) { + sc = SharkEnv.initWithSharkContext("shark-sql-suite-testing", MASTER) + + sc.runSql("set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=" + + METASTORE_PATH + ";create=true") + sc.runSql("set hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) + sc.runSql("set shark.test.data.path=" + TestUtils.dataFilePath) + + // second db + sc.sql("create database if not exists seconddb") + + loadTables() + } + sc + } + + def initWithJava(): JavaSharkContext = synchronized { + if (javaSc == null) { + javaSc = new JavaSharkContext(init()) + } + javaSc + } + + /** + * Tables accessible by any test. Their properties should remain constant across + * tests. + */ + def loadTables() = synchronized { + require(sc != null, "call init() to instantiate a SharkContext first") + + // Use the default namespace + sc.runSql("USE " + DEFAULT_DATABASE_NAME) + + // test + sc.runSql("drop table if exists test") + sc.runSql("CREATE TABLE test (key INT, val STRING)") + sc.runSql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv1.txt' INTO TABLE test") + sc.runSql("drop table if exists test_cached") + sc.runSql("CREATE TABLE test_cached AS SELECT * FROM test") + + // test_null + sc.runSql("drop table if exists test_null") + sc.runSql("CREATE TABLE test_null (key INT, val STRING)") + sc.runSql("""LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv3.txt' + INTO TABLE test_null""") + sc.runSql("drop table if exists test_null_cached") + sc.runSql("CREATE TABLE test_null_cached AS SELECT * FROM test_null") + + // clicks + sc.runSql("drop table if exists clicks") + sc.runSql("""create table clicks (id int, click int) + row format delimited fields terminated by '\t'""") + sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/clicks.txt' + OVERWRITE INTO TABLE clicks""") + sc.runSql("drop table if exists clicks_cached") + sc.runSql("create table clicks_cached as select * from clicks") + + // users + sc.runSql("drop table if exists users") + sc.runSql("""create table users (id int, name string) + row format delimited fields terminated by '\t'""") + sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/users.txt' + OVERWRITE INTO TABLE users""") + sc.runSql("drop table if exists users_cached") + sc.runSql("create table users_cached as select * from users") + + // test1 + sc.sql("drop table if exists test1") + sc.sql("""CREATE TABLE test1 (id INT, test1val ARRAY) + row format delimited fields terminated by '\t'""") + sc.sql("LOAD DATA LOCAL INPATH 
'${hiveconf:shark.test.data.path}/test1.txt' INTO TABLE test1") + sc.sql("drop table if exists test1_cached") + sc.sql("CREATE TABLE test1_cached AS SELECT * FROM test1") + Unit + } + + def expectSql(sql: String, expectedResults: Array[String], sort: Boolean = true) { + val sharkResults: Array[String] = sc.runSql(sql).results.map(_.mkString("\t")).toArray + val results = if (sort) sharkResults.sortWith(_ < _) else sharkResults + val expected = if (sort) expectedResults.sortWith(_ < _) else expectedResults + assert(results.corresponds(expected)(_.equals(_)), + "In SQL: " + sql + "\n" + + "Expected: " + expected.mkString("\n") + "; got " + results.mkString("\n")) + } + + // A shortcut for single row results. + def expectSql(sql: String, expectedResult: String) { + expectSql(sql, Array(expectedResult)) + } + +} diff --git a/src/test/scala/shark/SharkServerSuite.scala b/src/test/scala/shark/SharkServerSuite.scala index e5df4f98..1310ca04 100644 --- a/src/test/scala/shark/SharkServerSuite.scala +++ b/src/test/scala/shark/SharkServerSuite.scala @@ -10,7 +10,8 @@ import scala.collection.JavaConversions._ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.matchers.ShouldMatchers -import scala.concurrent.ops._ +import scala.concurrent._ +import ExecutionContext.Implicits.global /** * Test for the Shark server. @@ -57,7 +58,7 @@ class SharkServerSuite extends FunSuite with BeforeAndAfterAll with ShouldMatche // Spawn a thread to read the output from the forked process. // Note that this is necessary since in some configurations, log4j could be blocked // if its output to stderr are not read, and eventually blocking the entire test suite. - spawn { + future { while (true) { val stdout = readFrom(inputReader) val stderr = readFrom(errorReader) @@ -78,6 +79,7 @@ class SharkServerSuite extends FunSuite with BeforeAndAfterAll with ShouldMatche } test("test query execution against a shark server") { + Thread.sleep(5*1000) // I know... Gross. However, without this the tests fail non-deterministically. 
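+ // (Presumably this gives the forked server process time to finish starting up before the JDBC statement below is created.)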
val dataFilePath = TestUtils.dataFilePath + "/kv1.txt" val stmt = createStatement() diff --git a/src/test/scala/shark/SortSuite.scala b/src/test/scala/shark/SortSuite.scala index 4e7e9c05..df948a54 100644 --- a/src/test/scala/shark/SortSuite.scala +++ b/src/test/scala/shark/SortSuite.scala @@ -31,28 +31,23 @@ class SortSuite extends FunSuite { TestUtils.init() + var sc: SparkContext = SharkRunner.init() + test("order by limit") { - var sc: SparkContext = null - try { - sc = new SparkContext("local", "test") - val data = Array((4, 14), (1, 11), (7, 17), (0, 10)) - val expected = data.sortWith(_._1 < _._1).toSeq - val rdd: RDD[(ReduceKey, BytesWritable)] = sc.parallelize(data, 50).map { x => - (new ReduceKeyMapSide(new BytesWritable(Array[Byte](x._1.toByte))), - new BytesWritable(Array[Byte](x._2.toByte))) - } - for (k <- 0 to 5) { - val sortedRdd = RDDUtils.topK(rdd, k).asInstanceOf[RDD[(ReduceKeyReduceSide, Array[Byte])]] - val output = sortedRdd.map { case(k, v) => - (k.byteArray(0).toInt, v(0).toInt) - }.collect().toSeq - assert(output.size === math.min(k, 4)) - assert(output === expected.take(math.min(k, 4))) - } - } finally { - sc.stop() + val data = Array((4, 14), (1, 11), (7, 17), (0, 10)) + val expected = data.sortWith(_._1 < _._1).toSeq + val rdd: RDD[(ReduceKey, BytesWritable)] = sc.parallelize(data, 50).map { x => + (new ReduceKeyMapSide(new BytesWritable(Array[Byte](x._1.toByte))), + new BytesWritable(Array[Byte](x._2.toByte))) + } + for (k <- 0 to 5) { + val sortedRdd = RDDUtils.topK(rdd, k).asInstanceOf[RDD[(ReduceKeyReduceSide, Array[Byte])]] + val output = sortedRdd.map { case(k, v) => + (k.byteArray(0).toInt, v(0).toInt) + }.collect().toSeq + assert(output.size === math.min(k, 4)) + assert(output === expected.take(math.min(k, 4))) } - sc.stop() - System.clearProperty("spark.driver.port") } + } diff --git a/src/test/scala/shark/TachyonSQLSuite.scala b/src/test/scala/shark/TachyonSQLSuite.scala new file mode 100644 index 00000000..899bc1d4 --- /dev/null +++ b/src/test/scala/shark/TachyonSQLSuite.scala @@ -0,0 +1,437 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package shark + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.spark.rdd.UnionRDD +import org.apache.spark.storage.StorageLevel + +import shark.api.QueryExecutionException +import shark.memstore2.{CacheType, MemoryMetadataManager, PartitionedMemoryTable} +// import expectSql() shortcut methods +import shark.SharkRunner._ + + +class TachyonSQLSuite extends FunSuite with BeforeAndAfterAll { + + val DEFAULT_DB_NAME = DEFAULT_DATABASE_NAME + val KV1_TXT_PATH = "${hiveconf:shark.test.data.path}/kv1.txt" + + var sc: SharkContext = SharkRunner.init() + var sharkMetastore: MemoryMetadataManager = SharkEnv.memoryMetadataManager + + // Determine if Tachyon enabled at runtime. + val isTachyonEnabled = SharkEnv.tachyonUtil.tachyonEnabled() + + + override def beforeAll() { + if (isTachyonEnabled) { + sc.runSql("create table test_tachyon as select * from test") + } + } + + override def afterAll() { + if (isTachyonEnabled) { + sc.runSql("drop table test_tachyon") + } + } + + private def isTachyonTable( + dbName: String, + tableName: String, + hivePartitionKeyOpt: Option[String] = None): Boolean = { + val tableKey = MemoryMetadataManager.makeTableKey(dbName, tableName) + SharkEnv.tachyonUtil.tableExists(tableKey, hivePartitionKeyOpt) + } + + private def createPartitionedTachyonTable(tableName: String, numPartitionsToCreate: Int) { + sc.runSql("drop table if exists %s".format(tableName)) + sc.runSql(""" + create table %s(key int, value string) + partitioned by (keypart int) + tblproperties('shark.cache' = 'tachyon') + """.format(tableName)) + var partitionNum = 1 + while (partitionNum <= numPartitionsToCreate) { + sc.runSql("""insert into table %s partition(keypart = %d) + select * from test_tachyon""".format(tableName, partitionNum)) + partitionNum += 1 + } + assert(isTachyonTable(DEFAULT_DB_NAME, tableName)) + } + + if (isTachyonEnabled) { + ////////////////////////////////////////////////////////////////////////////// + // basic SQL + ////////////////////////////////////////////////////////////////////////////// + test("count") { + expectSql("select count(*) from test_tachyon", "500") + } + + test("filter") { + expectSql("select * from test_tachyon where key=100 or key=497", + Array("100\tval_100", "100\tval_100", "497\tval_497")) + } + + test("count distinct") { + sc.runSql("set mapred.reduce.tasks=3") + expectSql("select count(distinct key) from test_tachyon", "309") + expectSql( + """|SELECT substr(key,1,1), count(DISTINCT substr(val,5)) from test_tachyon + |GROUP BY substr(key,1,1)""".stripMargin, + Array("0\t1", "1\t71", "2\t69", "3\t62", "4\t74", "5\t6", "6\t5", "7\t6", "8\t8", "9\t7")) + } + + test("count bigint") { + sc.runSql("drop table if exists test_bigint") + sc.runSql("create table test_bigint (key bigint, val string)") + sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt' + OVERWRITE INTO TABLE test_bigint""") + sc.runSql("drop table if exists test_bigint_tachyon") + sc.runSql("create table test_bigint_tachyon as select * from test_bigint") + expectSql("select val, count(*) from test_bigint_tachyon where key=484 group by val", + "val_484\t1") + + sc.runSql("drop table if exists test_bigint_tachyon") + } + + test("limit") { + assert(sc.runSql("select * from test_tachyon limit 10").results.length === 
10) + assert(sc.runSql("select * from test_tachyon limit 501").results.length === 500) + sc.runSql("drop table if exists test_limit0_tachyon") + assert(sc.runSql("select * from test_tachyon limit 0").results.length === 0) + assert(sc.runSql("create table test_limit0_tachyon as select * from test_tachyon limit 0") + .results.length === 0) + assert(sc.runSql("select * from test_limit0_tachyon limit 0").results.length === 0) + assert(sc.runSql("select * from test_limit0_tachyon limit 1").results.length === 0) + + sc.runSql("drop table if exists test_limit0_tachyon") + } + + ////////////////////////////////////////////////////////////////////////////// + // cache DDL + ////////////////////////////////////////////////////////////////////////////// + test("Use regular CREATE TABLE and '_tachyon' suffix to create Tachyon table") { + sc.runSql("drop table if exists empty_table_tachyon") + sc.runSql("create table empty_table_tachyon(key string, value string)") + assert(isTachyonTable(DEFAULT_DB_NAME, "empty_table_tachyon")) + + sc.runSql("drop table if exists empty_table_tachyon") + } + + test("Use regular CREATE TABLE and table properties to create Tachyon table") { + sc.runSql("drop table if exists empty_table_tachyon_tbl_props") + sc.runSql("""create table empty_table_tachyon_tbl_props(key string, value string) + TBLPROPERTIES('shark.cache' = 'tachyon')""") + assert(isTachyonTable(DEFAULT_DB_NAME, "empty_table_tachyon_tbl_props")) + + sc.runSql("drop table if exists empty_table_tachyon_tbl_props") + } + + test("Insert into empty Tachyon table") { + sc.runSql("drop table if exists new_table_tachyon") + sc.runSql("create table new_table_tachyon(key string, value string)") + sc.runSql("insert into table new_table_tachyon select * from test where key > -1 limit 499") + expectSql("select count(*) from new_table_tachyon", "499") + + sc.runSql("drop table if exists new_table_tachyon") + } + + test("rename Tachyon table") { + sc.runSql("drop table if exists test_oldname_tachyon") + sc.runSql("drop table if exists test_rename") + sc.runSql("create table test_oldname_tachyon as select * from test") + sc.runSql("alter table test_oldname_tachyon rename to test_rename") + + assert(!isTachyonTable(DEFAULT_DB_NAME, "test_oldname_tachyon")) + assert(isTachyonTable(DEFAULT_DB_NAME, "test_rename")) + + expectSql("select count(*) from test_rename", "500") + + sc.runSql("drop table if exists test_rename") + } + + test("insert into tachyon tables") { + sc.runSql("drop table if exists test1_tachyon") + sc.runSql("create table test1_tachyon as select * from test") + expectSql("select count(*) from test1_tachyon", "500") + sc.runSql("insert into table test1_tachyon select * from test where key > -1 limit 499") + expectSql("select count(*) from test1_tachyon", "999") + + sc.runSql("drop table if exists test1_tachyon") + } + + test("insert overwrite") { + sc.runSql("drop table if exists test2_tachyon") + sc.runSql("create table test2_tachyon as select * from test") + expectSql("select count(*) from test2_tachyon", "500") + sc.runSql("insert overwrite table test2_tachyon select * from test where key > -1 limit 499") + expectSql("select count(*) from test2_tachyon", "499") + + sc.runSql("drop table if exists test2_tachyon") + } + + test("error when attempting to update Tachyon table(s) using command with multiple INSERTs") { + sc.runSql("drop table if exists multi_insert_test") + sc.runSql("drop table if exists multi_insert_test_tachyon") + sc.runSql("create table multi_insert_test as select * from test") + 
sc.runSql("create table multi_insert_test_tachyon as select * from test") + intercept[QueryExecutionException] { + sc.runSql("""from test + insert into table multi_insert_test select * + insert into table multi_insert_test_tachyon select *""") + } + + sc.runSql("drop table if exists multi_insert_test") + sc.runSql("drop table if exists multi_insert_test_tachyon") + } + + test("create Tachyon table with 'shark.cache' flag in table properties") { + sc.runSql("drop table if exists ctas_tbl_props") + sc.runSql("""create table ctas_tbl_props TBLPROPERTIES ('shark.cache'='tachyon') as + select * from test""") + assert(isTachyonTable(DEFAULT_DB_NAME, "ctas_tbl_props")) + expectSql("select * from ctas_tbl_props where key=407", "407\tval_407") + + sc.runSql("drop table if exists ctas_tbl_props") + } + + test("tachyon tables with complex types") { + sc.runSql("drop table if exists test_complex_types") + sc.runSql("drop table if exists test_complex_types_tachyon") + sc.runSql("""CREATE TABLE test_complex_types ( + a STRING, b ARRAY, c ARRAY>, d MAP>)""") + sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/create_nested_type.txt' + overwrite into table test_complex_types""") + sc.runSql("""create table test_complex_types_tachyon TBLPROPERTIES ("shark.cache" = "tachyon") + as select * from test_complex_types""") + + assert(sc.sql("select a from test_complex_types_tachyon where a = 'a0'").head === "a0") + + assert(sc.sql("select b from test_complex_types_tachyon where a = 'a0'").head === + """["b00","b01"]""") + + assert(sc.sql("select c from test_complex_types_tachyon where a = 'a0'").head === + """[{"c001":"C001","c002":"C002"},{"c011":null,"c012":"C012"}]""") + + assert(sc.sql("select d from test_complex_types_tachyon where a = 'a0'").head === + """{"d01":["d011","d012"],"d02":["d021","d022"]}""") + + assert(isTachyonTable(DEFAULT_DB_NAME, "test_complex_types_tachyon")) + + sc.runSql("drop table if exists test_complex_types") + sc.runSql("drop table if exists test_complex_types_tachyon") + } + + test("disable caching in Tachyon by default") { + sc.runSql("set shark.cache.flag.checkTableName=false") + sc.runSql("drop table if exists should_not_be_in_tachyon") + sc.runSql("create table should_not_be_in_tachyon as select * from test") + expectSql("select key from should_not_be_in_tachyon where key = 407", "407") + assert(!isTachyonTable(DEFAULT_DB_NAME, "should_not_be_in_tachyon")) + + sc.runSql("set shark.cache.flag.checkTableName=true") + sc.runSql("drop table if exists should_not_be_in_tachyon") + } + + test("tachyon table name should be case-insensitive") { + sc.runSql("drop table if exists sharkTest5tachyon") + sc.runSql("""create table sharkTest5tachyon TBLPROPERTIES ("shark.cache" = "tachyon") as + select * from test""") + expectSql("select val from sharktest5tachyon where key = 407", "val_407") + assert(isTachyonTable(DEFAULT_DB_NAME, "sharkTest5tachyon")) + + sc.runSql("drop table if exists sharkTest5tachyon") + } + + test("dropping tachyon tables should clean up RDDs") { + sc.runSql("drop table if exists sharkTest5tachyon") + sc.runSql("""create table sharkTest5tachyon TBLPROPERTIES ("shark.cache" = "tachyon") as + select * from test""") + sc.runSql("drop table sharkTest5tachyon") + assert(!isTachyonTable(DEFAULT_DB_NAME, "sharkTest5tachyon")) + } + + ////////////////////////////////////////////////////////////////////////////// + // Caching Hive-partititioned tables + // Note: references to 'partition' for this section refer to a Hive-partition. 
+ ////////////////////////////////////////////////////////////////////////////// + test("Use regular CREATE TABLE and '_tachyon' suffix to create partitioned Tachyon table") { + sc.runSql("drop table if exists empty_part_table_tachyon") + sc.runSql("""create table empty_part_table_tachyon(key int, value string) + partitioned by (keypart int)""") + assert(isTachyonTable(DEFAULT_DB_NAME, "empty_part_table_tachyon")) + + sc.runSql("drop table if exists empty_part_table_tachyon") + } + + test("Use regular CREATE TABLE and table properties to create partitioned Tachyon table") { + sc.runSql("drop table if exists empty_part_table_tachyon_tbl_props") + sc.runSql("""create table empty_part_table_tachyon_tbl_props(key int, value string) + partitioned by (keypart int) tblproperties('shark.cache' = 'tachyon')""") + assert(isTachyonTable(DEFAULT_DB_NAME, "empty_part_table_tachyon_tbl_props")) + + sc.runSql("drop table if exists empty_part_table_tachyon_tbl_props") + } + + test("alter Tachyon table by adding a new partition") { + sc.runSql("drop table if exists alter_part_tachyon") + sc.runSql("""create table alter_part_tachyon(key int, value string) + partitioned by (keypart int)""") + sc.runSql("""alter table alter_part_tachyon add partition(keypart = 1)""") + val tableName = "alter_part_tachyon" + val partitionColumn = "keypart=1" + assert(isTachyonTable(DEFAULT_DB_NAME, "alter_part_tachyon", Some(partitionColumn))) + + sc.runSql("drop table if exists alter_part_tachyon") + } + + test("alter Tachyon table by dropping a partition") { + sc.runSql("drop table if exists alter_drop_tachyon") + sc.runSql("""create table alter_drop_tachyon(key int, value string) + partitioned by (keypart int)""") + sc.runSql("""alter table alter_drop_tachyon add partition(keypart = 1)""") + + val tableName = "alter_drop_tachyon" + val partitionColumn = "keypart=1" + assert(isTachyonTable(DEFAULT_DB_NAME, "alter_drop_tachyon", Some(partitionColumn))) + sc.runSql("""alter table alter_drop_tachyon drop partition(keypart = 1)""") + assert(!isTachyonTable(DEFAULT_DB_NAME, "alter_drop_tachyon", Some(partitionColumn))) + + sc.runSql("drop table if exists alter_drop_tachyon") + } + + test("insert into a partition of a Tachyon table") { + val tableName = "insert_part_tachyon" + createPartitionedTachyonTable( + tableName, + numPartitionsToCreate = 1) + expectSql("select value from insert_part_tachyon where key = 407 and keypart = 1", "val_407") + + sc.runSql("drop table if exists insert_part_tachyon") + } + + test("insert overwrite a partition of a Tachyon table") { + val tableName = "insert_over_part_tachyon" + createPartitionedTachyonTable( + tableName, + numPartitionsToCreate = 1) + expectSql("""select value from insert_over_part_tachyon + where key = 407 and keypart = 1""", "val_407") + sc.runSql("""insert overwrite table insert_over_part_tachyon partition(keypart = 1) + select key, -1 from test""") + expectSql("select value from insert_over_part_tachyon where key = 407 and keypart = 1", "-1") + + sc.runSql("drop table if exists insert_over_part_tachyon") + } + + test("scan partitioned Tachyon table that's empty") { + sc.runSql("drop table if exists empty_part_table_tachyon") + sc.runSql("""create table empty_part_table_tachyon(key int, value string) + partitioned by (keypart int)""") + expectSql("select count(*) from empty_part_table_tachyon", "0") + + sc.runSql("drop table if exists empty_part_table_tachyon") + } + + test("scan partitioned Tachyon table that has a single partition") { + val tableName = 
"scan_single_part_tachyon" + createPartitionedTachyonTable( + tableName, + numPartitionsToCreate = 1) + expectSql("select * from scan_single_part_tachyon where key = 407", "407\tval_407\t1") + + sc.runSql("drop table if exists scan_single_part_tachyon") + } + + test("scan partitioned Tachyon table that has multiple partitions") { + val tableName = "scan_mult_part_tachyon" + createPartitionedTachyonTable( + tableName, + numPartitionsToCreate = 3) + expectSql("select * from scan_mult_part_tachyon where key = 407 order by keypart", + Array("407\tval_407\t1", "407\tval_407\t2", "407\tval_407\t3")) + + sc.runSql("drop table if exists scan_mult_part_tachyon") + } + + test("drop/unpersist partitioned Tachyon table that has multiple partitions") { + val tableName = "drop_mult_part_tachyon" + createPartitionedTachyonTable( + tableName, + numPartitionsToCreate = 3) + expectSql("select count(1) from drop_mult_part_tachyon", "1500") + sc.runSql("drop table drop_mult_part_tachyon ") + assert(!isTachyonTable(DEFAULT_DB_NAME, tableName)) + + sc.runSql("drop table if exists drop_mult_part_tachyon") + } + + ///////////////////////////////////////////////////////////////////////////// + // LOAD for Tachyon tables + ////////////////////////////////////////////////////////////////////////////// + test ("LOAD INTO a Tachyon table") { + sc.runSql("drop table if exists load_into_tachyon") + sc.runSql("create table load_into_tachyon (key int, value string)") + sc.runSql("load data local inpath '%s' into table load_into_tachyon".format(KV1_TXT_PATH)) + expectSql("select count(*) from load_into_tachyon", "500") + + sc.runSql("drop table if exists load_into_tachyon") + } + + test ("LOAD OVERWRITE a Tachyon table") { + sc.runSql("drop table if exists load_overwrite_tachyon") + sc.runSql("create table load_overwrite_tachyon (key int, value string)") + sc.runSql("load data local inpath '%s' into table load_overwrite_tachyon". + format("${hiveconf:shark.test.data.path}/kv3.txt")) + expectSql("select count(*) from load_overwrite_tachyon", "25") + sc.runSql("load data local inpath '%s' overwrite into table load_overwrite_tachyon". 
+ format(KV1_TXT_PATH)) + expectSql("select count(*) from load_overwrite_tachyon", "500") + sc.runSql("drop table if exists load_overwrite_tachyon") + } + + test ("LOAD INTO a partitioned Tachyon table") { + sc.runSql("drop table if exists load_into_part_tachyon") + sc.runSql("""create table load_into_part_tachyon (key int, value string) + partitioned by (keypart int)""") + sc.runSql("""load data local inpath '%s' into table load_into_part_tachyon + partition(keypart = 1)""".format(KV1_TXT_PATH)) + expectSql("select count(*) from load_into_part_tachyon", "500") + sc.runSql("drop table if exists load_into_part_tachyon") + } + + test ("LOAD OVERWRITE a partitioned Tachyon table") { + sc.runSql("drop table if exists load_overwrite_part_tachyon") + sc.runSql("""create table load_overwrite_part_tachyon (key int, value string) + partitioned by (keypart int)""") + sc.runSql("""load data local inpath '%s' overwrite into table load_overwrite_part_tachyon + partition(keypart = 1)""".format(KV1_TXT_PATH)) + expectSql("select count(*) from load_overwrite_part_tachyon", "500") + sc.runSql("drop table if exists load_overwrite_part_tachyon") + } + } +} diff --git a/src/test/scala/shark/TestUtils.scala b/src/test/scala/shark/TestUtils.scala index df2c264a..8bf0fd6f 100644 --- a/src/test/scala/shark/TestUtils.scala +++ b/src/test/scala/shark/TestUtils.scala @@ -24,6 +24,9 @@ import java.util.{Date, HashMap => JHashMap} import org.apache.hadoop.hive.common.LogUtils import org.apache.hadoop.hive.common.LogUtils.LogInitializationException +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.storage.StorageLevel + object TestUtils { @@ -48,6 +51,42 @@ object TestUtils { } } + def getStorageLevelOfRDD(rdd: RDD[_]): StorageLevel = { + rdd match { + case u: UnionRDD[_] => { + // Find the storage level of a UnionRDD from the storage levels of RDDs that compose it. + // A StorageLevel.NONE is returned if all of those RDDs have StorageLevel.NONE. + // Mutually recursive if any RDD in 'u.rdds' is a UnionRDD. + getStorageLevelOfRDDs(u.rdds) + } + case _ => rdd.getStorageLevel + } + } + + /** + * Returns the storage level of a sequence of RDDs, interpreted as the storage level of the first + * RDD in the sequence that is persisted in memory or on disk. This works because for Shark's use + * case, all RDDs for a non-partitioned table should have the same storage level. An RDD for a + * partitioned table could be StorageLevel.NONE if it was unpersisted by the partition eviction + * policy. + * + * @param rdds The sequence of RDDs to find the StorageLevel of. + */ + def getStorageLevelOfRDDs(rdds: Seq[RDD[_]]): StorageLevel = { + rdds.foldLeft(StorageLevel.NONE) { + (s, r) => { + if (s == StorageLevel.NONE) { + // Mutally recursive if `r` is a UnionRDD. However, this shouldn't happen in Shark, since + // UnionRDDs from successive INSERT INTOs are created through #unionAndFlatten(). + getStorageLevelOfRDD(r) + } else { + // Some RDD in 'rdds' is persisted in memory or disk, so return early. + return s + } + } + } + } + // Don't use default arguments in the above functions because otherwise the JavaAPISuite // can't call those functions with default arguments. 
def getWarehousePath(): String = getWarehousePath("sql") diff --git a/src/test/scala/shark/execution/serialization/SerializationSuite.scala b/src/test/scala/shark/execution/serialization/SerializationSuite.scala index f7f887e1..43c97a0c 100755 --- a/src/test/scala/shark/execution/serialization/SerializationSuite.scala +++ b/src/test/scala/shark/execution/serialization/SerializationSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.FunSuite +import org.apache.spark.SparkConf import org.apache.spark.serializer.{JavaSerializer => SparkJavaSerializer} @@ -51,7 +52,7 @@ class SerializationSuite extends FunSuite { val ois = KryoSerializationWrapper(new ArrayBuffer[ObjectInspector]) ois.value += oi - val ser = new SparkJavaSerializer + val ser = new SparkJavaSerializer(new SparkConf(loadDefaults = false)) val bytes = ser.newInstance().serialize(ois) val desered = ser.newInstance() .deserialize[KryoSerializationWrapper[ArrayBuffer[ObjectInspector]]](bytes) @@ -59,14 +60,6 @@ class SerializationSuite extends FunSuite { assert(desered.head.getTypeName() === oi.getTypeName()) } - test("HiveConf serialization test") { - val hiveConf = new HiveConf - val bytes = HiveConfSerializer.serialize(hiveConf) - val deseredConf = HiveConfSerializer.deserialize(bytes) - - assertHiveConfEquals(hiveConf, deseredConf) - } - test("Java serializing operators") { import shark.execution.{FileSinkOperator => SharkFileSinkOperator} @@ -75,7 +68,7 @@ class SerializationSuite extends FunSuite { operator.localHiveOp = new org.apache.hadoop.hive.ql.exec.FileSinkOperator val opWrapped = OperatorSerializationWrapper(operator) - val ser = new SparkJavaSerializer + val ser = new SparkJavaSerializer(new SparkConf(loadDefaults = false)) val bytes = ser.newInstance().serialize(opWrapped) val desered = ser.newInstance() .deserialize[OperatorSerializationWrapper[SharkFileSinkOperator]](bytes) diff --git a/src/test/scala/shark/memstore2/CachePolicySuite.scala b/src/test/scala/shark/memstore2/CachePolicySuite.scala new file mode 100644 index 00000000..9fe41d4a --- /dev/null +++ b/src/test/scala/shark/memstore2/CachePolicySuite.scala @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package shark.memstore2 + +import org.scalatest.FunSuite + +import scala.collection.mutable.HashMap + +class CachePolicySuite extends FunSuite { + + case class TestValue(var value: Int, var isCached: Boolean) + + class IdentifyKVGen(max: Int) { + val kvMap = new HashMap[Int, TestValue]() + for (i <- 0 until max) { + kvMap(i) = TestValue(i, isCached = false) + } + + def loadFunc(key: Int) = { + val value = kvMap(key) + value.isCached = true + value + } + + def evictionFunc(key: Int, value: TestValue) = { + value.isCached = false + } + } + + test("LRU policy") { + val kvGen = new IdentifyKVGen(20) + val cacheSize = 10 + val lru = new LRUCachePolicy[Int, TestValue]() + lru.initialize(Array.empty[String], cacheSize, kvGen.loadFunc _, kvGen.evictionFunc _) + + // Load KVs 0-9. + (0 to 9).map(lru.notifyGet(_)) + assert(lru.keysOfCachedEntries.equals(Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))) + + // Reorder access order by getting keys 2-4. + (2 to 4).map(lru.notifyGet(_)) + assert(lru.keysOfCachedEntries.equals(Seq(0, 1, 5, 6, 7, 8, 9, 2, 3, 4))) + + // Get keys 10-12, which should evict (0, 1, 5). + (10 to 12).map(lru.notifyGet(_)) + assert(lru.keysOfCachedEntries.equals(Seq(6, 7, 8, 9, 2, 3, 4, 10, 11, 12))) + // Make sure the eviction function ran. + assert(!kvGen.kvMap(0).isCached) + assert(!kvGen.kvMap(1).isCached) + assert(!kvGen.kvMap(5).isCached) + + // Reorder access order by getting keys (6, 8, 2). + lru.notifyGet(6); lru.notifyGet(8); lru.notifyGet(2) + assert(lru.keysOfCachedEntries.equals(Seq(7, 9, 3, 4, 10, 11, 12, 6, 8, 2))) + + // Remove 9, 4 and add 13, 14, 15. 7 should be evicted. + lru.notifyRemove(9); lru.notifyRemove(4) + (13 to 15).map(lru.notifyGet(_)) + assert(lru.keysOfCachedEntries.equals(Seq(3, 10, 11, 12, 6, 8, 2, 13, 14, 15))) + assert(!kvGen.kvMap(7).isCached) + } + + test("FIFO policy") { + val kvGen = new IdentifyKVGen(15) + val cacheSize = 5 + val fifo = new FIFOCachePolicy[Int, TestValue]() + fifo.initialize(Array.empty[String], cacheSize, kvGen.loadFunc _, kvGen.evictionFunc _) + + // Load KVs 0-4. + (0 to 4).map(fifo.notifyGet(_)) + assert(fifo.keysOfCachedEntries.equals(Seq(0, 1, 2, 3, 4))) + + // Get 0-8, which should evict 0-3. + (0 to 8).map(fifo.notifyGet(_)) + assert(fifo.keysOfCachedEntries.equals(Seq(4, 5, 6, 7, 8))) + + // Remove 4, 6 and add 9-12. 5 and 7 should be evicted. + fifo.notifyRemove(4); fifo.notifyRemove(6) + (9 to 12).map(fifo.notifyGet(_)) + assert(fifo.keysOfCachedEntries.equals(Seq(8, 9, 10, 11, 12))) + } + + test("Policy classes instantiated from a string, with maxSize argument") { + val kvGen = new IdentifyKVGen(15) + val lruStr = "shark.memstore2.LRUCachePolicy(5)" + val lru = CachePolicy.instantiateWithUserSpecs( + lruStr, fallbackMaxSize = 10, kvGen.loadFunc _, kvGen.evictionFunc _) + assert(lru.maxSize == 5) + val fifoStr = "shark.memstore2.FIFOCachePolicy(5)" + val fifo = CachePolicy.instantiateWithUserSpecs( + fifoStr, fallbackMaxSize = 10, kvGen.loadFunc _, kvGen.evictionFunc _) + assert(fifo.maxSize == 5) + } + + test("Cache stats are recorded") { + val kvGen = new IdentifyKVGen(20) + val cacheSize = 5 + val lru = new LRUCachePolicy[Int, TestValue]() + lru.initialize(Array.empty[String], cacheSize, kvGen.loadFunc _, kvGen.evictionFunc _) + + // Hit rate should start at 1.0 + assert(lru.hitRate == 1.0) + + (0 to 4).map(lru.notifyGet(_)) + assert(lru.hitRate == 0.0) + + // Get 1, 2, 3, which should bring the hit rate to 0.375. 
+ (1 to 3).map(lru.notifyGet(_)) + assert(lru.hitRate == 0.375) + + // Get 2-5, which brings the hit rate up to 0.50. + (2 to 5).map(lru.notifyGet(_)) + assert(lru.evictionCount == 1) + assert(lru.hitRate == 0.50) + } +} \ No newline at end of file diff --git a/src/test/scala/shark/memstore2/ColumnStatsSuite.scala b/src/test/scala/shark/memstore2/ColumnStatsSuite.scala index 21e55b21..2da1959c 100644 --- a/src/test/scala/shark/memstore2/ColumnStatsSuite.scala +++ b/src/test/scala/shark/memstore2/ColumnStatsSuite.scala @@ -18,6 +18,7 @@ package shark.memstore2 import java.sql.Timestamp +import scala.language.implicitConversions import org.apache.hadoop.io.Text @@ -54,7 +55,7 @@ class ColumnStatsSuite extends FunSuite { } test("ByteColumnStats") { - var c = new ColumnStats.ByteColumnStats + val c = new ColumnStats.ByteColumnStats c.append(0) assert(c.min == 0 && c.max == 0) c.append(1) @@ -72,7 +73,7 @@ class ColumnStatsSuite extends FunSuite { } test("ShortColumnStats") { - var c = new ColumnStats.ShortColumnStats + val c = new ColumnStats.ShortColumnStats c.append(0) assert(c.min == 0 && c.max == 0) c.append(1) @@ -123,7 +124,7 @@ class ColumnStatsSuite extends FunSuite { } test("LongColumnStats") { - var c = new ColumnStats.LongColumnStats + val c = new ColumnStats.LongColumnStats c.append(0) assert(c.min == 0 && c.max == 0) c.append(1) @@ -140,7 +141,7 @@ class ColumnStatsSuite extends FunSuite { } test("FloatColumnStats") { - var c = new ColumnStats.FloatColumnStats + val c = new ColumnStats.FloatColumnStats c.append(0) assert(c.min == 0 && c.max == 0) c.append(1) @@ -157,7 +158,7 @@ class ColumnStatsSuite extends FunSuite { } test("DoubleColumnStats") { - var c = new ColumnStats.DoubleColumnStats + val c = new ColumnStats.DoubleColumnStats c.append(0) assert(c.min == 0 && c.max == 0) c.append(1) @@ -174,7 +175,7 @@ class ColumnStatsSuite extends FunSuite { } test("TimestampColumnStats") { - var c = new ColumnStats.TimestampColumnStats + val c = new ColumnStats.TimestampColumnStats val ts1 = new Timestamp(1000) val ts2 = new Timestamp(2000) val ts3 = new Timestamp(1500) @@ -197,8 +198,13 @@ class ColumnStatsSuite extends FunSuite { test("StringColumnStats") { implicit def T(str: String): Text = new Text(str) - var c = new ColumnStats.StringColumnStats + val c = new ColumnStats.StringColumnStats assert(c.min == null && c.max == null) + + assert(!(c :> "test")) + assert(!(c :< "test")) + assert(!(c == "test")) + c.append("a") assert(c.min.equals(T("a")) && c.max.equals(T("a"))) diff --git a/src/test/scala/shark/memstore2/TablePartitionSuite.scala b/src/test/scala/shark/memstore2/TablePartitionSuite.scala index 047d4071..843cb1b1 100644 --- a/src/test/scala/shark/memstore2/TablePartitionSuite.scala +++ b/src/test/scala/shark/memstore2/TablePartitionSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite +import org.apache.spark.SparkConf import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} @@ -31,7 +32,7 @@ class TablePartitionSuite extends FunSuite { val col2 = Array[Byte](1, 2, 3) val tp = new TablePartition(3, Array(ByteBuffer.wrap(col1), ByteBuffer.wrap(col2))) - val ser = new JavaSerializer + val ser = new JavaSerializer(new SparkConf(false)) val bytes = ser.newInstance().serialize(tp) val tp1 = ser.newInstance().deserialize[TablePartition](bytes) assert(tp1.numRows === 3) @@ -58,7 +59,7 @@ class TablePartitionSuite extends FunSuite { col2.rewind() val tp = new TablePartition(3, Array(col1, col2)) - val ser = new JavaSerializer + val ser = 
new JavaSerializer(new SparkConf(false)) val bytes = ser.newInstance().serialize(tp) val tp1 = ser.newInstance().deserialize[TablePartition](bytes) assert(tp1.numRows === 3) @@ -77,7 +78,7 @@ class TablePartitionSuite extends FunSuite { val col2 = Array[Byte](1, 2, 3) val tp = new TablePartition(3, Array(ByteBuffer.wrap(col1), ByteBuffer.wrap(col2))) - val ser = new KryoSerializer + val ser = new KryoSerializer(new SparkConf(false)) val bytes = ser.newInstance().serialize(tp) val tp1 = ser.newInstance().deserialize[TablePartition](bytes) assert(tp1.numRows === 3) @@ -104,7 +105,7 @@ class TablePartitionSuite extends FunSuite { col2.rewind() val tp = new TablePartition(3, Array(col1, col2)) - val ser = new KryoSerializer + val ser = new KryoSerializer(new SparkConf(false)) val bytes = ser.newInstance().serialize(tp) val tp1 = ser.newInstance().deserialize[TablePartition](bytes) assert(tp1.numRows === 3) diff --git a/src/test/scala/shark/memstore2/column/ColumnTypeSuite.scala b/src/test/scala/shark/memstore2/column/ColumnTypeSuite.scala index ec959bf7..1ea2f7a6 100644 --- a/src/test/scala/shark/memstore2/column/ColumnTypeSuite.scala +++ b/src/test/scala/shark/memstore2/column/ColumnTypeSuite.scala @@ -1,11 +1,30 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package shark.memstore2.column -import org.scalatest.FunSuite import java.nio.ByteBuffer + import org.apache.hadoop.io.IntWritable import org.apache.hadoop.io.LongWritable import org.apache.hadoop.hive.serde2.io._ +import org.scalatest.FunSuite + class ColumnTypeSuite extends FunSuite { test("Int") { @@ -14,30 +33,30 @@ class ColumnTypeSuite extends FunSuite { var a: Seq[Int] = Array[Int](35, 67, 899, 4569001) a.foreach {i => buffer.putInt(i)} buffer.rewind() - a.foreach {i => - val v = INT.extract(buffer.position(), buffer) + a.foreach {i => + val v = INT.extract(buffer) assert(v == i) } buffer = ByteBuffer.allocate(32) a = Range(0, 4) - a.foreach { i => - INT.append(i, buffer) + a.foreach { i => + INT.append(i, buffer) } buffer.rewind() a.foreach { i => assert(buffer.getInt() == i)} - + buffer = ByteBuffer.allocate(32) a =Range(0,4) a.foreach { i => buffer.putInt(i)} buffer.rewind() val writable = new IntWritable() - a.foreach { i => - INT.extractInto(buffer.position(), buffer, writable) + a.foreach { i => + INT.extractInto(buffer, writable) assert(writable.get == i) } - + } - + test("Short") { assert(SHORT.defaultSize == 2) assert(SHORT.actualSize(8) == 2) @@ -45,30 +64,30 @@ class ColumnTypeSuite extends FunSuite { var a = Array[Short](35, 67, 87, 45) a.foreach {i => buffer.putShort(i)} buffer.rewind() - a.foreach {i => - val v = SHORT.extract(buffer.position(), buffer) + a.foreach {i => + val v = SHORT.extract(buffer) assert(v == i) } - + buffer = ByteBuffer.allocate(32) a = Array[Short](0,1,2,3) - a.foreach { i => - SHORT.append(i, buffer) + a.foreach { i => + SHORT.append(i, buffer) } buffer.rewind() a.foreach { i => assert(buffer.getShort() == i)} - + buffer = ByteBuffer.allocate(32) a =Array[Short](0,1,2,3) a.foreach { i => buffer.putShort(i)} buffer.rewind() val writable = new ShortWritable() - a.foreach { i => - SHORT.extractInto(buffer.position(), buffer, writable) + a.foreach { i => + SHORT.extractInto(buffer, writable) assert(writable.get == i) } } - + test("Long") { assert(LONG.defaultSize == 8) assert(LONG.actualSize(45L) == 8) @@ -76,26 +95,26 @@ class ColumnTypeSuite extends FunSuite { var a = Array[Long](35L, 67L, 8799000880L, 45000999090L) a.foreach {i => buffer.putLong(i)} buffer.rewind() - a.foreach {i => - val v = LONG.extract(buffer.position(), buffer) + a.foreach {i => + val v = LONG.extract(buffer) assert(v == i) } - + buffer = ByteBuffer.allocate(32) a = Array[Long](0,1,2,3) - a.foreach { i => - LONG.append(i, buffer) + a.foreach { i => + LONG.append(i, buffer) } buffer.rewind() a.foreach { i => assert(buffer.getLong() == i)} - + buffer = ByteBuffer.allocate(32) a =Array[Long](0,1,2,3) a.foreach { i => buffer.putLong(i)} buffer.rewind() val writable = new LongWritable() - a.foreach { i => - LONG.extractInto(buffer.position(), buffer, writable) + a.foreach { i => + LONG.extractInto(buffer, writable) assert(writable.get == i) } } diff --git a/src/test/scala/shark/memstore2/column/CompressedColumnIteratorSuite.scala b/src/test/scala/shark/memstore2/column/CompressedColumnIteratorSuite.scala index 322de1e1..6ed0aa4d 100644 --- a/src/test/scala/shark/memstore2/column/CompressedColumnIteratorSuite.scala +++ b/src/test/scala/shark/memstore2/column/CompressedColumnIteratorSuite.scala @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column import java.nio.ByteBuffer @@ -5,61 +22,155 @@ import java.nio.ByteOrder import org.scalatest.FunSuite import org.apache.hadoop.io.Text +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector import shark.memstore2.column.Implicits._ class CompressedColumnIteratorSuite extends FunSuite { - - test("RLE Decompression") { - val b = ByteBuffer.allocate(1024) + + /** + * Generic tester across types and encodings. The function applies the given compression + * algorithm on the given sequence of values, and test whether the resulting iterator gives + * the same sequence of values. + * + * If we expect the compression algorithm to not compress the data, we should set the + * shouldNotCompress flag to true. This way, it doesn't actually create a compressed buffer, + * but simply tests the compression ratio returned by the algorithm is >= 1.0. + */ + def testList[T, W]( + l: Seq[T], + t: ColumnType[T, _], + algo: CompressionAlgorithm, + expectedCompressedSize: Int, + shouldNotCompress: Boolean = false) + { + val b = ByteBuffer.allocate(1024 + (3 * 40 * l.size)) b.order(ByteOrder.nativeOrder()) - b.putInt(STRING.typeID) - val rle = new RLE() - - Array(new Text("abc"), new Text("abc"), new Text("efg"), new Text("abc")).foreach { text => - STRING.append(text, b) - rle.gatherStatsForCompressibility(text, STRING) + b.putInt(t.typeID) + l.foreach { item => + t.append(item, b) + algo.gatherStatsForCompressibility(item, t.asInstanceOf[ColumnType[Any, _]]) } b.limit(b.position()) b.rewind() - val compressedBuffer = rle.compress(b, STRING) - val iter = new TestIterator(compressedBuffer, compressedBuffer.getInt()) - iter.next() - assert(iter.current.toString().equals("abc")) - iter.next() - assert(iter.current.toString().equals("abc")) - assert(iter.current.toString().equals("abc")) - iter.next() - assert(iter.current.toString().equals("efg")) - iter.next() - assert(iter.current.toString().equals("abc")) - } - - test("Dictionary Decompression") { - val b = ByteBuffer.allocate(1024) - b.order(ByteOrder.nativeOrder()) - b.putInt(STRING.typeID) - val dict = new DictionaryEncoding() - - Array(new Text("abc"), new Text("abc"), new Text("efg"), new Text("abc")).foreach { text => - STRING.append(text, b) - dict.gatherStatsForCompressibility(text, STRING) + + info("compressed size: %d, uncompressed size: %d, compression ratio %f".format( + algo.compressedSize, algo.uncompressedSize, algo.compressionRatio)) + + assert(algo.compressedSize === expectedCompressedSize) + + if (shouldNotCompress) { + assert(algo.compressionRatio >= 1.0) + } else { + val compressedBuffer = algo.compress(b, t) + val iter = new TestIterator(compressedBuffer, compressedBuffer.getInt()) + + val oi: ObjectInspector = t match { + case BOOLEAN => PrimitiveObjectInspectorFactory.writableBooleanObjectInspector + case BYTE => PrimitiveObjectInspectorFactory.writableByteObjectInspector + case SHORT => PrimitiveObjectInspectorFactory.writableShortObjectInspector + case INT => 
PrimitiveObjectInspectorFactory.writableIntObjectInspector
+        case LONG => PrimitiveObjectInspectorFactory.writableLongObjectInspector
+        case STRING => PrimitiveObjectInspectorFactory.writableStringObjectInspector
+        case _ => throw new UnsupportedOperationException("Unsupported compression type " + t)
+      }
+
+      l.foreach { x =>
+        iter.next()
+        assert(t.get(iter.current, oi) === x)
+      }
+
+      // Make sure we reach the end of the iterator.
+      assert(!iter.hasNext)
     }
+  }
+
+  test("RLE Boolean") {
+    // 3 runs: (1+4)*3
+    val bools = Seq(true, true, false, true, true, true, true, true, true, true, true, true)
+    testList(bools, BOOLEAN, new RLE, 15)
+  }
+
+  test("RLE Byte") {
+    // 3 runs: (1+4)*3
+    testList(Seq[Byte](10, 10, 10, 10, 10, 10, 10, 10, 10, 20, 10), BYTE, new RLE, 15)
+  }
+
+  test("RLE Short") {
+    // 3 runs: (2+4)*3
+    testList(Seq[Short](10, 10, 10, 20000, 20000, 20000, 500, 500, 500, 500), SHORT, new RLE, 18)
+  }
+
+  test("RLE Int") {
+    // 3 runs: (4+4)*3
+    testList(Seq[Int](1000000, 1000000, 1000000, 1000000, 900000, 99), INT, new RLE, 24)
+  }
+
+  test("RLE Long") {
+    // 2 runs: (8+4)*2
+    val longs = Seq[Long](2147483649L, 2147483649L, 2147483649L, 2147483649L, 500L, 500L, 500L)
+    testList(longs, LONG, new RLE, 24)
+  }
+
+  test("RLE String") {
+    // 3 runs: (4+4+4) + (4+1+4) + (4+1+4) = 30
+    val strs: Seq[Text] = Seq("abcd", "abcd", "abcd", "e", "e", "!", "!").map(s => new Text(s))
+    testList(strs, STRING, new RLE, 30)
+  }
+
+  test("Dictionary Encoded Int") {
+    // dict len + 3 distinct values + 7 values = 4 + 3*4 + 7*2 = 30
+    val ints = Seq[Int](1000000, 1000000, 99, 1000000, 1000000, 900000, 99)
+    testList(ints, INT, new DictionaryEncoding, 30)
+  }
+
+  test("Dictionary Encoded Long") {
+    // dict len + 2 distinct values + 7 values = 4 + 2*8 + 7*2 = 34
+    val longs = Seq[Long](2147483649L, 2147483649L, 2147483649L, 2147483649L, 500L, 500L, 500L)
+    testList(longs, LONG, new DictionaryEncoding, 34)
+  }
+
+  test("Dictionary Encoded String") {
+    // dict len + 3 distinct values + 8 values = 4 + (4+4) + (4+1) + (4+1) + 8*2 = 38
+    val strs: Seq[Text] = Seq("abcd", "abcd", "abcd", "e", "e", "e", "!", "!").map(s => new Text(s))
+    testList(strs, STRING, new DictionaryEncoding, 38, shouldNotCompress = false)
+  }
+
+  test("Dictionary Encoding at limit of unique values") {
+    val ints = Range(0, Short.MaxValue - 1).flatMap(i => Iterator(i, i, i))
+    val expectedLen = 4 + (Short.MaxValue - 1) * 4 + 2 * (Short.MaxValue - 1) * 3
+    testList(ints, INT, new DictionaryEncoding, expectedLen)
+  }
+
+  test("Dictionary Encoding - should not compress") {
+    val ints = Range(0, Short.MaxValue.toInt)
+    testList(ints, INT, new DictionaryEncoding, Int.MaxValue, shouldNotCompress = true)
+  }
+
+  test("RLE - should not compress") {
+    val ints = Range(0, Short.MaxValue.toInt + 1)
+    val expectedLen = (Short.MaxValue.toInt + 1) * (4 + 4)
+    testList(ints, INT, new RLE, expectedLen, shouldNotCompress = true)
+  }
+
+  test("BooleanBitSet Boolean (shorter)") {
+    // 1 Long worth of Booleans, in addition to the length field: 4+8
+    val bools = Seq(true, true, false, false)
+    testList(bools, 
BOOLEAN, new BooleanBitSetCompression, 4+8) + } + + test("BooleanBitSet Boolean (longer)") { + // 2 Longs worth of Booleans, in addtion to the length field: 4+8+8 + val bools = Seq(true, true, false, false, true, true, false, false,true, true, false, false,true, true, false, false, + true, true, false, false,true, true, false, false, true, true, false, false,true, true, false, false, + true, true, false, false,true, true, false, false, true, true, false, false,true, true, false, false, + true, true, false, false,true, true, false, false, true, true, false, false,true, true, false, false, + true, true, false, false,true, true, false, false, true, true, false, false,true, true, false, false) + testList(bools, BOOLEAN, new BooleanBitSetCompression, 4+8+8) } } - class TestIterator(val buffer: ByteBuffer, val columnType: ColumnType[_,_]) - extends CompressedColumnIterator + +class TestIterator(val buffer: ByteBuffer, val columnType: ColumnType[_,_]) + extends CompressedColumnIterator diff --git a/src/test/scala/shark/memstore2/column/CompressionAlgorithmSuite.scala b/src/test/scala/shark/memstore2/column/CompressionAlgorithmSuite.scala index 83d3b717..57eab675 100644 --- a/src/test/scala/shark/memstore2/column/CompressionAlgorithmSuite.scala +++ b/src/test/scala/shark/memstore2/column/CompressionAlgorithmSuite.scala @@ -1,23 +1,45 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column -import java.nio.ByteBuffer -import java.nio.ByteOrder +import java.nio.{ByteBuffer, ByteOrder} + import scala.collection.mutable.HashMap -import org.scalatest.FunSuite import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.Text +import org.scalatest.FunSuite + import shark.memstore2.column.ColumnStats._ class CompressionAlgorithmSuite extends FunSuite { - test("Compressed Column Builder") { + // TODO: clean these tests. 
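For reference, the compressed-size expectations asserted in CompressedColumnIteratorSuite above (and in the run-by-run assertions below) follow directly from the encoded layouts described in the comments: RLE writes each run as the raw value followed by a 4-byte run length, and dictionary encoding writes a 4-byte dictionary length, each distinct value once, then a 2-byte Short index per row. A minimal sketch of that arithmetic, assuming those layouts (rleSize and dictSize are hypothetical helpers, not part of Shark):

    // RLE: each run is stored as (value bytes + 4-byte run-length Int).
    def rleSize[T](values: Seq[T], valueSize: Int): Int = {
      val runs = values.foldLeft(List.empty[(T, Int)]) {
        case ((v, n) :: tail, x) if v == x => (v, n + 1) :: tail
        case (acc, x) => (x, 1) :: acc
      }
      runs.size * (valueSize + 4)
    }

    // Dictionary: 4-byte dictionary length + each distinct value once + a 2-byte index per row.
    def dictSize[T](values: Seq[T], valueSize: Int): Int =
      4 + values.distinct.size * valueSize + 2 * values.size

    // rleSize(Seq[Byte](10, 10, 10, 10, 10, 10, 10, 10, 10, 20, 10), 1) == 15      // 3 byte runs
    // rleSize(Seq[Int](1000000, 1000000, 1000000, 1000000, 900000, 99), 4) == 24   // 3 int runs
    // dictSize(Seq(1000000, 1000000, 99, 1000000, 1000000, 900000, 99), 4) == 30   // 3 distinct, 7 rows

Variable-width values such as strings also carry a 4-byte length per stored value, which is where the string totals of 30 (RLE) and 38 (dictionary) above come from.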
+ + test("CompressedColumnBuilder using RLE") { + class TestColumnBuilder(val stats: ColumnStats[Int], val t: ColumnType[Int,_]) - extends CompressedColumnBuilder[Int] { + extends CompressedColumnBuilder[Int] { compressionSchemes = Seq(new RLE()) override def shouldApply(scheme: CompressionAlgorithm) = true } + val b = new TestColumnBuilder(new NoOpStats, INT) b.initialize(100) val oi = PrimitiveObjectInspectorFactory.javaIntObjectInspector @@ -31,7 +53,9 @@ class CompressionAlgorithmSuite extends FunSuite { assert(compressedBuffer.getInt() == 123) assert(compressedBuffer.getInt() == 2) - + assert(compressedBuffer.getInt() == 56) + assert(compressedBuffer.getInt() == 2) + assert(!compressedBuffer.hasRemaining) } test("RLE Strings") { @@ -39,10 +63,7 @@ class CompressionAlgorithmSuite extends FunSuite { b.order(ByteOrder.nativeOrder()) b.putInt(STRING.typeID) val rle = new RLE() - Array[Text](new Text("abc"), - new Text("abc"), - new Text("efg"), - new Text("abc")).foreach { text => + Seq[Text](new Text("abc"), new Text("abc"), new Text("efg"), new Text("abc")).foreach { text => STRING.append(text, b) rle.gatherStatsForCompressibility(text, STRING) } @@ -51,13 +72,16 @@ class CompressionAlgorithmSuite extends FunSuite { val compressedBuffer = rle.compress(b, STRING) assert(compressedBuffer.getInt() == STRING.typeID) assert(compressedBuffer.getInt() == RLECompressionType.typeID) - assert(STRING.extract(compressedBuffer.position(), compressedBuffer).equals (new Text("abc"))) + assert(STRING.extract(compressedBuffer).equals(new Text("abc"))) assert(compressedBuffer.getInt() == 2) - assert(STRING.extract(compressedBuffer.position(), compressedBuffer).equals (new Text("efg"))) + assert(STRING.extract(compressedBuffer).equals(new Text("efg"))) + assert(compressedBuffer.getInt() == 1) + assert(STRING.extract(compressedBuffer).equals(new Text("abc"))) assert(compressedBuffer.getInt() == 1) + assert(!compressedBuffer.hasRemaining) } - test("RLE no encoding") { + test("RLE int with run length 1") { val b = ByteBuffer.allocate(16) b.order(ByteOrder.nativeOrder()) b.putInt(INT.typeID) @@ -75,14 +99,15 @@ class CompressionAlgorithmSuite extends FunSuite { assert(compressedBuffer.getInt() == 1) assert(compressedBuffer.getInt() == 56) assert(compressedBuffer.getInt() == 1) + assert(!compressedBuffer.hasRemaining) } - - test("RLE perfect encoding") { + + test("RLE int single run") { val b = ByteBuffer.allocate(4008) b.order(ByteOrder.nativeOrder()) b.putInt(INT.typeID) val rle = new RLE() - Range(0,1000).foreach { x => + Range(0, 1000).foreach { x => b.putInt(6) rle.gatherStatsForCompressibility(6, INT) } @@ -93,16 +118,36 @@ class CompressionAlgorithmSuite extends FunSuite { assert(compressedBuffer.getInt() == RLECompressionType.typeID) assert(compressedBuffer.getInt() == 6) assert(compressedBuffer.getInt() == 1000) + assert(!compressedBuffer.hasRemaining) } - - test("RLE mixture") { + + test("RLE long single run") { + val b = ByteBuffer.allocate(8008) + b.order(ByteOrder.nativeOrder()) + b.putInt(LONG.typeID) + val rle = new RLE() + Range(0, 1000).foreach { x => + b.putLong(Long.MaxValue - 6) + rle.gatherStatsForCompressibility(Long.MaxValue - 6, LONG) + } + b.limit(b.position()) + b.rewind() + val compressedBuffer = rle.compress(b, LONG) + assert(compressedBuffer.getInt() == LONG.typeID) + assert(compressedBuffer.getInt() == RLECompressionType.typeID) + assert(compressedBuffer.getLong() == Long.MaxValue - 6) + assert(compressedBuffer.getInt() == 1000) + assert(!compressedBuffer.hasRemaining) + } + + 
test("RLE int 3 runs") { val b = ByteBuffer.allocate(4008) b.order(ByteOrder.nativeOrder()) b.putInt(INT.typeID) val items = Array[Int](10, 20, 40) val rle = new RLE() - Range(0,1000).foreach { x => + Range(0, 1000).foreach { x => val v = if (x < 100) items(0) else if (x < 500) items(1) else items(2) b.putInt(v) rle.gatherStatsForCompressibility(v, INT) @@ -116,80 +161,131 @@ class CompressionAlgorithmSuite extends FunSuite { assert(compressedBuffer.getInt() == 100) assert(compressedBuffer.getInt() == 20) assert(compressedBuffer.getInt() == 400) + assert(compressedBuffer.getInt() == 40) + assert(compressedBuffer.getInt() == 500) + assert(!compressedBuffer.hasRemaining) } - - test("RLE perf") { + + test("RLE int single long run") { val b = ByteBuffer.allocate(4000008) b.order(ByteOrder.nativeOrder()) b.putInt(INT.typeID) val rle = new RLE() - Range(0,1000000).foreach { x => + Range(0, 1000000).foreach { x => b.putInt(6) rle.gatherStatsForCompressibility(6, INT) } b.limit(b.position()) b.rewind() val compressedBuffer = rle.compress(b, INT) - //first 4 bytes is the compression scheme assert(compressedBuffer.getInt() == RLECompressionType.typeID) assert(compressedBuffer.getInt() == INT.typeID) assert(compressedBuffer.getInt() == 6) assert(compressedBuffer.getInt() == 1000000) + assert(!compressedBuffer.hasRemaining) } - + test("Dictionary Encoding") { - val b = ByteBuffer.allocate(1024) - b.order(ByteOrder.nativeOrder()) - b.putInt(STRING.typeID) - val de = new DictionaryEncoding() - Array[Text](new Text("abc"), - new Text("abc"), - new Text("efg"), - new Text("abc")).foreach { text => - STRING.append(text, b) - de.gatherStatsForCompressibility(text, STRING) + + def testList[T]( + l: Seq[T], + u: ColumnType[T, _], + expectedDictSize: Int, + compareFunc: (T, T) => Boolean = (a: T, b: T) => a == b) { + + val b = ByteBuffer.allocate(1024 + (3*40*l.size)) + b.order(ByteOrder.nativeOrder()) + b.putInt(u.typeID) + val de = new DictionaryEncoding() + l.foreach { item => + u.append(item, b) + de.gatherStatsForCompressibility(item, u.asInstanceOf[ColumnType[Any, _]]) + } + b.limit(b.position()) + b.rewind() + val compressedBuffer = de.compress(b, u) + assert(compressedBuffer.getInt() === u.typeID) + assert(compressedBuffer.getInt() === DictionaryCompressionType.typeID) + assert(compressedBuffer.getInt() === expectedDictSize) //dictionary size + val dictionary = new HashMap[Short, T]() + var count = 0 + while (count < expectedDictSize) { + val v = u.extract(compressedBuffer) + dictionary.put(dictionary.size.toShort, u.clone(v)) + count += 1 + } + assert(dictionary.get(0).get.equals(l(0))) + assert(dictionary.get(1).get.equals(l(2))) + l.foreach { x => + val y = dictionary.get(compressedBuffer.getShort()).get + assert(compareFunc(y, x)) } - b.limit(b.position()) - b.rewind() - val compressedBuffer = de.compress(b, STRING) - assert(compressedBuffer.getInt() == STRING.typeID) - assert(compressedBuffer.getInt() == DictionaryCompressionType.typeID) - assert(compressedBuffer.getInt() == 2) //dictionary size - val dictionary = new HashMap[Int, Text]() - var count = 0 - while (count < 2) { - val v = STRING.extract(compressedBuffer.position(), compressedBuffer) - val index = compressedBuffer.getInt() - dictionary.put(index, v) - count += 1 } - assert(dictionary.get(0).get.equals(new Text("abc"))) - assert(dictionary.get(1).get.equals(new Text("efg"))) - //read the next 4 items - assert(compressedBuffer.getInt() == 0) - assert(compressedBuffer.getInt() == 0) - assert(compressedBuffer.getInt() == 1) - 
assert(compressedBuffer.getInt() == 0) + + val iList = Array[Int](10, 10, 20, 10) + val lList = iList.map { i => Long.MaxValue - i.toLong } + val sList = iList.map { i => new Text(i.toString) } + + testList(iList, INT, 2) + testList(lList, LONG, 2) + testList(sList, STRING, 2, (a: Text, b: Text) => a.hashCode == b.hashCode) + + // test at limit of unique values + val alternating = Range(0, Short.MaxValue-1, 1).flatMap { s => List(1, s) } + val longList = List.concat(iList, alternating, iList) + assert(longList.size === (8 + 2*(Short.MaxValue-1))) + testList(longList, INT, Short.MaxValue - 1) } - - test("RLE region") { + + test("Uncompressed text") { val b = new StringColumnBuilder b.initialize(0) val oi = PrimitiveObjectInspectorFactory.javaStringObjectInspector - val lines = Array[String]("lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to", + val lines = Array[String]( + "lar deposits. blithely final packages cajole. regular waters are final requests.", "hs use ironic, even requests. s", "ges. thinly even pinto beans ca", "ly final courts cajole furiously final excuse", - "uickly special accounts cajole carefully blithely close requests. carefully final asymptotes haggle furiousl" + "uickly special accounts cajole carefully blithely close requests. carefully final" ) lines.foreach { line => b.append(line, oi) } val newBuffer = b.build() - assert(newBuffer.getInt() == STRING.typeID) - assert(newBuffer.getInt() == RLECompressionType.typeID) - + assert(newBuffer.getInt() === 0) // null count + assert(newBuffer.getInt() === STRING.typeID) + assert(newBuffer.getInt() === DefaultCompressionType.typeID) + } + + test("BooleanBitSet encoding") { + val bbs = new BooleanBitSetCompression() + val b = ByteBuffer.allocate(4 + 64 + 2) + b.order(ByteOrder.nativeOrder()) + b.putInt(BOOLEAN.typeID) + for(_ <- 1 to 5) { + b.put(0.toByte) + b.put(1.toByte) + bbs.gatherStatsForCompressibility(false, BOOLEAN) + bbs.gatherStatsForCompressibility(true, BOOLEAN) + } + for(_ <- 1 to 54) { + b.put(0.toByte) + bbs.gatherStatsForCompressibility(false, BOOLEAN) + } + b.put(0.toByte) + b.put(1.toByte) + bbs.gatherStatsForCompressibility(false, BOOLEAN) + bbs.gatherStatsForCompressibility(true, BOOLEAN) + b.limit(b.position()) + b.rewind() + val compressedBuffer = bbs.compress(b, BOOLEAN) + assert(compressedBuffer.getInt() === BOOLEAN.typeID) + assert(compressedBuffer.getInt() === BooleanBitSetCompressionType.typeID) + assert(compressedBuffer.getInt() === 64 + 2) + assert(compressedBuffer.getLong() === 682) + assert(compressedBuffer.getLong() === 2) + assert(!compressedBuffer.hasRemaining) } -} \ No newline at end of file +} diff --git a/src/test/scala/shark/memstore2/column/NullableColumnBuilderSuite.scala b/src/test/scala/shark/memstore2/column/NullableColumnBuilderSuite.scala index d2c6a0dc..92e58760 100644 --- a/src/test/scala/shark/memstore2/column/NullableColumnBuilderSuite.scala +++ b/src/test/scala/shark/memstore2/column/NullableColumnBuilderSuite.scala @@ -1,28 +1,42 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package shark.memstore2.column -import org.scalatest.FunSuite import org.apache.hadoop.io.Text import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.scalatest.FunSuite class NullableColumnBuilderSuite extends FunSuite { - test("Perf") { - val c = new StringColumnBuilder() - c.initialize(1024*1024*8) - val oi = PrimitiveObjectInspectorFactory.writableStringObjectInspector - Range(0, 1000000).foreach { i => - c.append(new Text("00000000000000000000000000000000" + i), oi) - } + test("Empty column") { + val c = new IntColumnBuilder() + c.initialize(4) val b = c.build() - val i = ColumnIterator.newIterator(b) - Range(0, 1000000).foreach { x => - i.next() - i.current - } + // # of nulls + assert(b.getInt() === 0) + // column type + assert(b.getInt() === INT.typeID) + assert(b.getInt() === DefaultCompressionType.typeID) + assert(!b.hasRemaining) } - test("Grow") { + test("Buffer size auto growth") { val c = new StringColumnBuilder() c.initialize(4) val oi = PrimitiveObjectInspectorFactory.writableStringObjectInspector @@ -35,8 +49,9 @@ class NullableColumnBuilderSuite extends FunSuite { c.append(null, oi) c.append(new Text("efg"), oi) val b = c.build() + b.position(4 + 4 * 4) val colType = b.getInt() - assert(colType == STRING.typeID) + assert(colType === STRING.typeID) } test("Null Strings") { @@ -48,22 +63,27 @@ class NullableColumnBuilderSuite extends FunSuite { c.append(new Text("b"), oi) c.append(null, oi) val b = c.build() - //expect first element is col type - assert(b.getInt() == STRING.typeID) - //next comes # of nulls - assert(b.getInt() == 2) - //typeID of first null is 1, that of second null is 3 - assert(b.getInt() == 1) - assert(b.getInt() == 3) - - //next comes the compression type - assert(b.getInt() == -1) - assert(b.getInt() == 1) - assert(b.get() == 97) - assert(b.getInt() == 1) - assert(b.get() == 98) + + // Number of nulls + assert(b.getInt() === 2) + + // First null position is 1, and then 3 + assert(b.getInt() === 1) + assert(b.getInt() === 3) + + // Column data type + assert(b.getInt() === STRING.typeID) + + // Compression type + assert(b.getInt() === DefaultCompressionType.typeID) + + // Data + assert(b.getInt() === 1) + assert(b.get() === 97) + assert(b.getInt() === 1) + assert(b.get() === 98) } - + test("Null Ints") { val c = new IntColumnBuilder() c.initialize(4) @@ -73,17 +93,34 @@ class NullableColumnBuilderSuite extends FunSuite { c.append(null, oi) c.append(56.asInstanceOf[Object], oi) val b = c.build() - //expect first element is col type - assert(b.getInt() == INT.typeID) - //next comes # of nulls - assert(b.getInt() == 2) - //typeID of first null is 1, that of second null is 3 - assert(b.getInt() == 1) - assert(b.getInt() == 2) - assert(b.getInt() == -1) - assert(b.getInt() == 123) + + // # of nulls and null positions + assert(b.getInt() === 2) + assert(b.getInt() === 1) + assert(b.getInt() === 2) + + // non nulls + assert(b.getInt() === INT.typeID) + assert(b.getInt() === DefaultCompressionType.typeID) + assert(b.getInt() === 123) + } + + test("Nullable Ints 2") { + val c = new 
IntColumnBuilder() + c.initialize(4) + val oi = PrimitiveObjectInspectorFactory.javaIntObjectInspector + Range(1, 1000).foreach { x => + c.append(x.asInstanceOf[Object], oi) + } + val b = c.build() + // null count + assert(b.getInt() === 0) + // column type + assert(b.getInt() === INT.typeID) + // compression type + assert(b.getInt() === DefaultCompressionType.typeID) } - + test("Null Longs") { val c = new LongColumnBuilder() c.initialize(4) @@ -93,26 +130,16 @@ class NullableColumnBuilderSuite extends FunSuite { c.append(null, oi) c.append(56L.asInstanceOf[Object], oi) val b = c.build() - //expect first element is col type - assert(b.getInt() == LONG.typeID) - //next comes # of nulls - assert(b.getInt() == 2) - //typeID of first null is 1, that of second null is 3 - assert(b.getInt() == 1) - assert(b.getInt() == 2) - assert(b.getInt() == -1) - assert(b.getLong() == 123L) - } - - test("Trigger RLE") { - val c = new IntColumnBuilder() - c.initialize(4) - val oi = PrimitiveObjectInspectorFactory.javaIntObjectInspector - Range(1,1000).foreach { x => - c.append(x.asInstanceOf[Object], oi) - } - val b = c.build() - assert(b.getInt() == INT.typeID) - assert(b.getInt() == RLECompressionType.typeID) + + // # of nulls and null positions + assert(b.getInt() === 2) + assert(b.getInt() === 1) + assert(b.getInt() === 2) + + // non-nulls + assert(b.getInt() === LONG.typeID) + assert(b.getInt() === DefaultCompressionType.typeID) + assert(b.getLong() === 123L) } -} \ No newline at end of file + +} diff --git a/src/test/scala/shark/memstore2/column/NullableColumnIteratorSuite.scala b/src/test/scala/shark/memstore2/column/NullableColumnIteratorSuite.scala index b5d210eb..614fc625 100644 --- a/src/test/scala/shark/memstore2/column/NullableColumnIteratorSuite.scala +++ b/src/test/scala/shark/memstore2/column/NullableColumnIteratorSuite.scala @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2012 The Regents of The University California. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package shark.memstore2.column import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory @@ -25,6 +42,7 @@ class NullableColumnIteratorSuite extends FunSuite { val b = c.build() val i = ColumnIterator.newIterator(b) Range(0, a.length).foreach { x => + if (x > 0) assert(i.hasNext) i.next() val v = i.current if (a(x) == null) { @@ -33,7 +51,9 @@ class NullableColumnIteratorSuite extends FunSuite { assert(v.toString == a(x).toString) } } + assert(!i.hasNext) } + test("Iterate Strings") { val oi = PrimitiveObjectInspectorFactory.writableStringObjectInspector val c = ColumnBuilder.create(oi) @@ -58,23 +78,35 @@ class NullableColumnIteratorSuite extends FunSuite { assert(i.current.toString() == "Abcdz") i.next() assert(i.current == null) + assert(false === i.hasNext) } - + test("Iterate Ints") { - val oi = PrimitiveObjectInspectorFactory.javaIntObjectInspector - val c = ColumnBuilder.create(oi) - c.initialize(4) - c.append(123.asInstanceOf[Object],oi) - c.append(null, oi) - c.append(null, oi) - c.append(56.asInstanceOf[Object], oi) - val b = c.build() - val i = ColumnIterator.newIterator(b) - i.next() - assert(i.current.asInstanceOf[IntWritable].get() == 123) - i.next() - assert(i.current == null) - i.next() - assert(i.current == null) + def testList(l: Seq[AnyRef]) { + val oi = PrimitiveObjectInspectorFactory.javaIntObjectInspector + val c = ColumnBuilder.create(oi) + c.initialize(l.size) + + l.foreach { item => + c.append(item, oi) + } + + val b = c.build() + val i = ColumnIterator.newIterator(b) + + l.foreach { x => + i.next() + if (x == null) { + assert(i.current === x) + } else { + assert(i.current.asInstanceOf[IntWritable].get === x) + } + } + assert(false === i.hasNext) + } + + testList(List(null, null, 123.asInstanceOf[AnyRef])) + testList(List(123.asInstanceOf[AnyRef], 4.asInstanceOf[AnyRef], null)) + testList(List(null)) } -} \ No newline at end of file +} diff --git a/src/test/scala/shark/util/BloomFilterSuite.scala b/src/test/scala/shark/util/BloomFilterSuite.scala index 6126650f..31171d7e 100644 --- a/src/test/scala/shark/util/BloomFilterSuite.scala +++ b/src/test/scala/shark/util/BloomFilterSuite.scala @@ -5,13 +5,13 @@ import org.scalatest.FunSuite class BloomFilterSuite extends FunSuite{ test("Integer") { - val bf = new BloomFilter(0.03,1000000) - Range(0,1000000).foreach { + val bf = new BloomFilter(0.03, 1000000) + Range(0, 1000000).foreach { i => bf.add(i) } assert(bf.contains(333)) assert(bf.contains(678)) - assert(bf.contains(1200000) == false) + assert(!bf.contains(1200000)) } test("Integer FP") { @@ -26,12 +26,10 @@ class BloomFilterSuite extends FunSuite{ i => bf.contains(i*10) } val s = e.groupBy(x => x).map(x => (x._1, x._2.size)) - println(s) val t = s(true) val f = s(false) assert(f > 25 && f < 35) assert(t < 75 && t > 65) // expect false positive to be < 3 % and no false negatives - } } \ No newline at end of file
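The constructor arguments exercised in BloomFilterSuite (a 3% false-positive target over 1,000,000 insertions) line up with the textbook Bloom filter sizing formulas. A small sketch of that arithmetic, assuming the standard optimal-parameter equations rather than shark.util.BloomFilter's exact internals:

    import scala.math.{ceil, log, pow, round}

    // Optimal bit-array size m and hash count k for n elements at false-positive rate p:
    //   m = -n * ln(p) / (ln 2)^2,   k = (m / n) * ln 2
    def bloomParams(p: Double, n: Long): (Long, Int) = {
      val m = ceil(-n * log(p) / pow(log(2), 2)).toLong
      val k = round((m.toDouble / n) * log(2)).toInt
      (m, k)
    }

    // bloomParams(0.03, 1000000L)  ->  roughly 7.3 million bits and 5 hash functions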