Merge remote-tracking branch 'apache/master' into SPARK-8536-LDA-asymmetric-priors

* apache/master:
  [SPARK-9224] [MLLIB] OnlineLDA Performance Improvements
  [SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin
  [SPARK-9165] [SQL] codegen for CreateArray, CreateStruct and CreateNamedStruct
  [SPARK-9082] [SQL] Filter using non-deterministic expressions should not be pushed down
  [SPARK-9254] [BUILD] [HOTFIX] sbt-launch-lib.bash should support HTTP/HTTPS redirection
  [SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation Improvement
  [SPARK-9232] [SQL] Duplicate code in JSONRelation
  [SPARK-9121] [SPARKR] Get rid of the warnings about `no visible global function definition` in SparkR
  [SPARK-9154][SQL] Rename formatString to format_string.
  [SPARK-9154] [SQL] codegen StringFormat
  [SPARK-9206] [SQL] Fix HiveContext classloading for GCS connector.
  [SPARK-8906][SQL] Move all internal data source classes into execution.datasources.
  [SPARK-8357] Fix unsafe memory leak on empty inputs in GeneratedAggregate
  Revert "[SPARK-9154] [SQL] codegen StringFormat"
Feynman Liang committed Jul 22, 2015
2 parents 58f1d7b + 8486cd8 commit ef5821d
Showing 100 changed files with 3,966 additions and 411 deletions.
8 changes: 6 additions & 2 deletions build/sbt-launch-lib.bash
@@ -51,9 +51,13 @@ acquire_sbt_jar () {
printf "Attempting to fetch sbt\n"
JAR_DL="${JAR}.part"
if [ $(command -v curl) ]; then
(curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
(curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\
(rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\
mv "${JAR_DL}" "${JAR}"
elif [ $(command -v wget) ]; then
(wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
(wget --quiet ${URL1} -O "${JAR_DL}" ||\
(rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\
mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
exit -1
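
The change above hardens the sbt bootstrap: `--fail` keeps curl from saving an HTML error page as the jar, `--location` follows the HTTP/HTTPS redirects mentioned in SPARK-9254, and the partial download is deleted (`rm -f`) before the second URL is tried. A rough Scala sketch of the same fetch-with-fallback idea follows; the names and error handling are illustrative, not part of the commit.

import java.io.IOException
import java.net.{HttpURLConnection, URL}
import java.nio.file.{Files, Path, StandardCopyOption}

object FetchWithFallback {
  // Download `url` into `dest`: treat non-2xx responses as failure (analogue of
  // curl --fail), follow redirects (analogue of curl --location), and remove the
  // partial file on error.
  def fetch(url: String, dest: Path): Boolean = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setInstanceFollowRedirects(true)
    try {
      if (conn.getResponseCode / 100 != 2) {
        false
      } else {
        val in = conn.getInputStream
        try Files.copy(in, dest, StandardCopyOption.REPLACE_EXISTING) finally in.close()
        true
      }
    } catch {
      case _: IOException =>
        Files.deleteIfExists(dest) // analogue of `rm -f "${JAR_DL}"` before the retry
        false
    } finally {
      conn.disconnect()
    }
  }

  // Try the primary URL first, then the mirror, mirroring the (URL1 || URL2) chain.
  def fetchWithFallback(primary: String, mirror: String, dest: Path): Boolean =
    fetch(primary, dest) || fetch(mirror, dest)
}
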
12 changes: 9 additions & 3 deletions dev/lint-r.R
@@ -15,15 +15,21 @@
# limitations under the License.
#

argv <- commandArgs(TRUE)
SPARK_ROOT_DIR <- as.character(argv[1])

# Installs lintr from Github.
# NOTE: The CRAN's version is too old to adapt to our rules.
if ("lintr" %in% row.names(installed.packages()) == FALSE) {
devtools::install_github("jimhester/lintr")
}
library(lintr)

argv <- commandArgs(TRUE)
SPARK_ROOT_DIR <- as.character(argv[1])
library(lintr)
library(methods)
library(testthat)
if (! library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) {
stop("You should install SparkR in a local directory with `R/install-dev.sh`.")
}

path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg")
lint_package(path.to.package, cache = FALSE)
@@ -27,7 +27,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector}
import org.apache.spark.rdd.RDD

/**
@@ -385,7 +385,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
iteration += 1
val k = this.k
val vocabSize = this.vocabSize
val Elogbeta = dirichletExpectation(lambda)
val Elogbeta = dirichletExpectation(lambda).t
val expElogbeta = exp(Elogbeta)
val alpha = this.alpha.toBreeze
val gammaShape = this.gammaShape
@@ -400,41 +400,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
case v => throw new IllegalArgumentException("Online LDA does not support vector type "
+ v.getClass)
}
if (!ids.isEmpty) {

// Initialize the variational distribution q(theta|gamma) for the mini-batch
val gammad: BDV[Double] =
new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K

val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
var meanchange = 1D
val ctsVector = new BDV[Double](cts) // ids

// Iterate between gamma and phi until convergence
while (meanchange > 1e-3) {
val lastgamma = gammad.copy
// K K * ids ids
gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
phinorm := expElogbetad * expElogthetad :+ 1e-100
meanchange = sum(abs(gammad - lastgamma)) / k
}

// Initialize the variational distribution q(theta|gamma) for the mini-batch
var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
var expElogthetad = exp(Elogthetad) // 1 * K
val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids

var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
var meanchange = 1D
val ctsVector = new BDV[Double](cts).t // 1 * ids

// Iterate between gamma and phi until convergence
while (meanchange > 1e-3) {
val lastgamma = gammad
// 1*K 1 * ids ids * k
gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) :+ alpha.t
Elogthetad = digamma(gammad) - digamma(sum(gammad))
expElogthetad = exp(Elogthetad)
phinorm = expElogthetad * expElogbetad + 1e-100
meanchange = sum(abs(gammad - lastgamma)) / k
}

val m1 = expElogthetad.t
val m2 = (ctsVector / phinorm).t.toDenseVector
var i = 0
while (i < ids.size) {
stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
i += 1
stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
}
}
Iterator(stat)
}

val statsSum: BDM[Double] = stats.reduce(_ += _)
val batchResult = statsSum :* expElogbeta
val batchResult = statsSum :* expElogbeta.t

// Note that this is an optimization to avoid batch.count
update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
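
Both variants of the per-document loop that appear in this hunk (the earlier row-vector formulation and the reworked column-vector one) implement the same online variational Bayes update of Hoffman, Blei and Bach (2010). Written with the code's variable names, the update iterated until `meanchange` drops below 1e-3 is roughly the following LaTeX sketch:

\[
\begin{aligned}
\mathbb{E}[\log\theta_{dk}] &= \psi(\gamma_{dk}) - \psi\Big(\sum_{k'}\gamma_{dk'}\Big) && \text{(Elogthetad)} \\
\mathrm{phinorm}_w &= \sum_k e^{\mathbb{E}[\log\theta_{dk}]}\, e^{\mathbb{E}[\log\beta_{kw}]} + 10^{-100} && \text{(phinorm)} \\
\gamma_{dk} &\leftarrow \alpha_k + e^{\mathbb{E}[\log\theta_{dk}]} \sum_w \frac{n_{dw}}{\mathrm{phinorm}_w}\, e^{\mathbb{E}[\log\beta_{kw}]} && \text{(gammad)} \\
s_{kw} &\leftarrow s_{kw} + e^{\mathbb{E}[\log\theta_{dk}]}\, \frac{n_{dw}}{\mathrm{phinorm}_w} && \text{(the stat accumulation)}
\end{aligned}
\]

Here n_dw are the token counts in `ctsVector` and E[log beta] is `Elogbeta = dirichletExpectation(lambda)`. The `.t` added to `Elogbeta` only changes its storage orientation (vocabSize x k rather than k x vocabSize) so that the per-document slice becomes `expElogbeta(ids, ::)` and the inner loop can work on dense column vectors, which appears to be the crux of the SPARK-9224 performance change.
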
47 changes: 47 additions & 0 deletions project/MimaExcludes.scala
@@ -104,6 +104,53 @@ object MimaExcludes {
// SPARK-7422 add argmax for sparse vectors
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.argmax")
) ++ Seq(
// SPARK-8906 Move all internal data source classes into execution.datasources
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException")
)

case v if v.startsWith("1.4") =>
@@ -20,10 +20,11 @@
import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;


@@ -354,7 +355,7 @@ public double getDouble(int i) {
* This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
public InternalRow copy() {
public UnsafeRow copy() {
if (pool != null) {
throw new UnsupportedOperationException(
"Copy is not supported for UnsafeRows that use object pools");
@@ -404,8 +405,51 @@ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOExcepti
}
}

@Override
public int hashCode() {
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
}

@Override
public boolean equals(Object other) {
if (other instanceof UnsafeRow) {
UnsafeRow o = (UnsafeRow) other;
return (sizeInBytes == o.sizeInBytes) &&
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
sizeInBytes);
}
return false;
}

/**
* Returns the underlying bytes for this UnsafeRow.
*/
public byte[] getBytes() {
if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
&& (((byte[]) baseObject).length == sizeInBytes)) {
return (byte[]) baseObject;
} else {
byte[] bytes = new byte[sizeInBytes];
PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
return bytes;
}
}

// This is for debugging
@Override
public String toString() {
StringBuilder build = new StringBuilder("[");
for (int i = 0; i < sizeInBytes; i += 8) {
build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
build.append(',');
}
build.append(']');
return build.toString();
}

@Override
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
}
}
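
The new `hashCode`/`equals` pair defines row identity purely by the backing byte region: Murmur3 over the row's words for hashing, byte-wise comparison for equality. That contract is what lets UnsafeRows serve as hash-table keys in the unsafe joins from SPARK-9024, and `getBytes` exposes the same region (without copying when the row exactly covers its backing array). A minimal self-contained analogue of the contract, using a simplified stand-in class rather than Spark's:

import java.util.Arrays

// Stand-in for an UnsafeRow-like value whose identity is its backing bytes.
final class ByteRow(val bytes: Array[Byte]) {
  // Stand-in for Murmur3_x86_32.hashUnsafeWords over (baseObject, baseOffset, sizeInBytes).
  override def hashCode: Int = Arrays.hashCode(bytes)
  // Stand-in for ByteArrayMethods.arrayEquals on the two byte regions.
  override def equals(other: Any): Boolean = other match {
    case o: ByteRow => Arrays.equals(bytes, o.bytes)
    case _          => false
  }
}

object ByteRowDemo extends App {
  val a = new ByteRow(Array[Byte](1, 2, 3))
  val b = new ByteRow(Array[Byte](1, 2, 3))
  // Equal byte regions agree on equals and hashCode, so rows can key a hash map.
  assert(a == b && a.hashCode == b.hashCode)
}
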
@@ -28,11 +28,10 @@
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -176,12 +175,7 @@ public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IO
*/
public static boolean supportsSchema(StructType schema) {
// TODO: add spilling note to explain why we do this for now:
for (StructField field : schema.fields()) {
if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
return false;
}
}
return true;
return UnsafeProjection.canSupport(schema);
}

private static final class RowComparator extends RecordComparator {
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
}
}
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
lexical.normalizeKeyword(udfName) match {
case "sum" => SumDistinct(exprs.head)
case "count" => CountDistinct(exprs)
case name => UnresolvedFunction(name, exprs, isDistinct = true)
case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
}
}
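
With this change the parser stops rejecting `f(DISTINCT ...)` outright for functions other than SUM and COUNT; it records the DISTINCT keyword on the UnresolvedFunction and defers the decision to the analyzer. A simplified sketch of what the DISTINCT alternative now produces — the case classes below are stand-ins, not Spark's, and `toLowerCase` stands in for `lexical.normalizeKeyword`:

object DistinctParseSketch {
  // Stand-ins for the expression nodes the grammar rule builds.
  sealed trait Expr
  case class Literal(value: Any) extends Expr
  case class SumDistinct(child: Expr) extends Expr
  case class CountDistinct(children: Seq[Expr]) extends Expr
  case class UnresolvedFunction(name: String, children: Seq[Expr], isDistinct: Boolean) extends Expr

  // Dispatch performed by the DISTINCT alternative of the rule above.
  def parseDistinctCall(name: String, args: Seq[Expr]): Expr = name.toLowerCase match {
    case "sum"   => SumDistinct(args.head)
    case "count" => CountDistinct(args)
    case other   => UnresolvedFunction(other, args, isDistinct = true) // previously an AnalysisException
  }

  // parseDistinctCall("count", Seq(Literal(1)))  == CountDistinct(Seq(Literal(1)))
  // parseDistinctCall("myAgg", Seq(Literal(1)))  == UnresolvedFunction("myAgg", Seq(Literal(1)), isDistinct = true)
}
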
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressions {
case u @ UnresolvedFunction(name, children) =>
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children)
registry.lookupFunction(name, children) match {
// We get an aggregate function built based on AggregateFunction2 interface.
// So, we wrap it in AggregateExpression2.
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
// Currently, our old aggregate function interface supports SUM(DISTINCT ...)
// and COUNT(DISTINCT ...).
case sumDistinct: SumDistinct => sumDistinct
case countDistinct: CountDistinct => countDistinct
// DISTINCT is not meaningful with Max and Min.
case max: Max if isDistinct => max
case min: Min if isDistinct => min
// For other aggregate functions, DISTINCT keyword is not supported for now.
// Once we converted to the new code path, we will allow using DISTINCT keyword.
case other if isDistinct =>
failAnalysis(s"$name does not support DISTINCT keyword.")
// If it does not have DISTINCT keyword, we will return it as is.
case other => other
}
}
}
}
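
The resolution rule now reconciles the parser's `isDistinct` flag with two generations of aggregate interfaces: functions built on the new AggregateFunction2 are wrapped in AggregateExpression2 together with the flag, the legacy SumDistinct/CountDistinct forms pass through, DISTINCT is a no-op for MAX and MIN, and any other function combined with DISTINCT fails analysis. A condensed, self-contained sketch of that dispatch, with stand-in types rather than Spark's:

object DistinctResolutionSketch {
  sealed trait Fn
  case class NewStyleAgg(name: String) extends Fn                        // stand-in for AggregateFunction2
  case class WrappedAgg(fn: NewStyleAgg, isDistinct: Boolean) extends Fn // stand-in for AggregateExpression2(_, Complete, _)
  case class LegacyDistinct(name: String) extends Fn                     // stand-in for SumDistinct / CountDistinct
  case class MinOrMax(name: String) extends Fn
  case class LegacyOther(name: String) extends Fn

  def resolve(fn: Fn, isDistinct: Boolean): Fn = fn match {
    case agg: NewStyleAgg       => WrappedAgg(agg, isDistinct)
    case legacy: LegacyDistinct => legacy
    case mm: MinOrMax           => mm                                    // DISTINCT changes nothing for MIN/MAX
    case other if isDistinct    => sys.error(s"$other does not support DISTINCT keyword.")
    case other                  => other
  }
}
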
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

@@ -168,7 +168,8 @@ object FunctionRegistry {
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
expression[StringFormat]("printf"),
expression[FormatString]("format_string"),
expression[FormatString]("printf"),
expression[StringRPad]("rpad"),
expression[StringRepeat]("repeat"),
expression[StringReverse]("reverse"),
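
After the SPARK-9154 rename the same expression is registered under both `format_string` and `printf`. Its formatting is expected to follow java.util.Formatter semantics, i.e. the same behaviour as the JVM formatter behind Scala's `StringOps.format`; a plain-Scala illustration of those semantics (not Spark code):

object FormatStringSketch extends App {
  // format_string('%s scored %d points', name, score) in SQL should format the
  // way the JVM formatter does:
  val msg = "%s scored %d points".format("alice", 42)
  println(msg) // alice scored 42 points
}
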
@@ -73,7 +73,10 @@ object UnresolvedAttribute {
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
}

case class UnresolvedFunction(name: String, children: Seq[Expression])
case class UnresolvedFunction(
name: String,
children: Seq[Expression],
isDistinct: Boolean)
extends Expression with Unevaluable {

override def dataType: DataType = throw new UnresolvedException(this, "dataType")
(The remaining changed files in this commit are not shown here.)
