(new byte[0], exceedsMaxRecordSize);
diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
new file mode 100644
index 0000000000000..bc9f3708ed69d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.ByteBuffer
+
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.avro.{Schema, SchemaBuilder}
+import org.apache.avro.generic.GenericData.Record
+
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+
+class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
+ conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+
+ val schema: Schema = SchemaBuilder
+ .record("testRecord").fields()
+ .requiredString("data")
+ .endRecord()
+ val record = new Record(schema)
+ record.put("data", "test data")
+
+ test("schema compression and decompression") {
+ val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+ assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
+ }
+
+ test("record serialization and deserialization") {
+ val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+
+ val outputStream = new ByteArrayOutputStream()
+ val output = new Output(outputStream)
+ genericSer.serializeDatum(record, output)
+ output.flush()
+ output.close()
+
+ val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
+ assert(genericSer.deserializeDatum(input) === record)
+ }
+
+ test("uses schema fingerprint to decrease message size") {
+ val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
+
+ val output = new Output(new ByteArrayOutputStream())
+
+ val beginningNormalPosition = output.total()
+ genericSerFull.serializeDatum(record, output)
+ output.flush()
+ val normalLength = output.total - beginningNormalPosition
+
+ conf.registerAvroSchemas(schema)
+ val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
+ val beginningFingerprintPosition = output.total()
+ genericSerFinger.serializeDatum(record, output)
+ val fingerprintLength = output.total - beginningFingerprintPosition
+
+ assert(fingerprintLength < normalLength)
+ }
+
+ test("caches previously seen schemas") {
+ val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+ val compressedSchema = genericSer.compress(schema)
+ val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))
+
+ assert(compressedSchema.eq(genericSer.compress(schema)))
+ assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
index 96778c9ebafb1..f495b6a037958 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
@@ -17,26 +17,39 @@
package org.apache.spark.shuffle
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.Mockito._
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
-import java.util.concurrent.atomic.AtomicBoolean
-import java.util.concurrent.CountDownLatch
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, TaskContext}
class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
+
+ val nextTaskAttemptId = new AtomicInteger()
+
/** Launch a thread with the given body block and return it. */
private def startThread(name: String)(body: => Unit): Thread = {
val thread = new Thread("ShuffleMemorySuite " + name) {
override def run() {
- body
+ try {
+ val taskAttemptId = nextTaskAttemptId.getAndIncrement
+ val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS)
+ when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId)
+ TaskContext.setTaskContext(mockTaskContext)
+ body
+ } finally {
+ TaskContext.unset()
+ }
}
}
thread.start()
thread
}
- test("single thread requesting memory") {
+ test("single task requesting memory") {
val manager = new ShuffleMemoryManager(1000L)
assert(manager.tryToAcquire(100L) === 100L)
@@ -50,7 +63,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
assert(manager.tryToAcquire(300L) === 300L)
assert(manager.tryToAcquire(300L) === 200L)
- manager.releaseMemoryForThisThread()
+ manager.releaseMemoryForThisTask()
assert(manager.tryToAcquire(1000L) === 1000L)
assert(manager.tryToAcquire(100L) === 0L)
}
@@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
}
- test("threads cannot grow past 1 / N") {
- // Two threads request 250 bytes first, wait for each other to get it, and then request
+ test("tasks cannot grow past 1 / N") {
+ // Two tasks request 250 bytes first, wait for each other to get it, and then request
// 500 more; we should only grant 250 bytes to each of them on this second request
val manager = new ShuffleMemoryManager(1000L)
@@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
assert(state.t2Result2 === 250L)
}
- test("threads can block to get at least 1 / 2N memory") {
+ test("tasks can block to get at least 1 / 2N memory") {
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
// for a bit and releases 250 bytes, which should then be granted to t2. Further requests
// by t2 will return false right away because it now has 1 / 2N of the memory.
@@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
}
}
- test("releaseMemoryForThisThread") {
+ test("releaseMemoryForThisTask") {
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
// for a bit and releases all its memory. t2 should now be able to grab all the memory.
@@ -251,9 +264,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
}
}
// Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
- // sure the other thread blocks for some time otherwise
+ // sure the other task blocks for some time otherwise
Thread.sleep(300)
- manager.releaseMemoryForThisThread()
+ manager.releaseMemoryForThisTask()
}
val t2 = startThread("t2") {
@@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
t2.join()
}
- // Both threads should've been able to acquire their memory; the second one will have waited
+ // Both tasks should've been able to acquire their memory; the second one will have waited
// until the first one acquired 1000 bytes and then released all of it
state.synchronized {
assert(state.t1Result === 1000L, "t1 could not allocate memory")
@@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
}
}
- test("threads should not be granted a negative size") {
+ test("tasks should not be granted a negative size") {
val manager = new ShuffleMemoryManager(1000L)
manager.tryToAcquire(700L)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index bcee901f5dd5f..f480fd107a0c2 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -1004,32 +1004,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
store = makeBlockManager(12000)
val memoryStore = store.memoryStore
assert(memoryStore.currentUnrollMemory === 0)
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
// Reserve
- memoryStore.reserveUnrollMemoryForThisThread(100)
- assert(memoryStore.currentUnrollMemoryForThisThread === 100)
- memoryStore.reserveUnrollMemoryForThisThread(200)
- assert(memoryStore.currentUnrollMemoryForThisThread === 300)
- memoryStore.reserveUnrollMemoryForThisThread(500)
- assert(memoryStore.currentUnrollMemoryForThisThread === 800)
- memoryStore.reserveUnrollMemoryForThisThread(1000000)
- assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted
+ memoryStore.reserveUnrollMemoryForThisTask(100)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 100)
+ memoryStore.reserveUnrollMemoryForThisTask(200)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 300)
+ memoryStore.reserveUnrollMemoryForThisTask(500)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 800)
+ memoryStore.reserveUnrollMemoryForThisTask(1000000)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted
// Release
- memoryStore.releaseUnrollMemoryForThisThread(100)
- assert(memoryStore.currentUnrollMemoryForThisThread === 700)
- memoryStore.releaseUnrollMemoryForThisThread(100)
- assert(memoryStore.currentUnrollMemoryForThisThread === 600)
+ memoryStore.releaseUnrollMemoryForThisTask(100)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 700)
+ memoryStore.releaseUnrollMemoryForThisTask(100)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 600)
// Reserve again
- memoryStore.reserveUnrollMemoryForThisThread(4400)
- assert(memoryStore.currentUnrollMemoryForThisThread === 5000)
- memoryStore.reserveUnrollMemoryForThisThread(20000)
- assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted
+ memoryStore.reserveUnrollMemoryForThisTask(4400)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 5000)
+ memoryStore.reserveUnrollMemoryForThisTask(20000)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted
// Release again
- memoryStore.releaseUnrollMemoryForThisThread(1000)
- assert(memoryStore.currentUnrollMemoryForThisThread === 4000)
- memoryStore.releaseUnrollMemoryForThisThread() // release all
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ memoryStore.releaseUnrollMemoryForThisTask(1000)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 4000)
+ memoryStore.releaseUnrollMemoryForThisTask() // release all
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
/**
@@ -1060,24 +1060,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
val bigList = List.fill(40)(new Array[Byte](1000))
val memoryStore = store.memoryStore
val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
// Unroll with all the space in the world. This should succeed and return an array.
var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
- memoryStore.releasePendingUnrollMemoryForThisThread()
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+ memoryStore.releasePendingUnrollMemoryForThisTask()
// Unroll with not enough space. This should succeed after kicking out someBlock1.
store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY)
store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY)
unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
assert(droppedBlocks.size === 1)
assert(droppedBlocks.head._1 === TestBlockId("someBlock1"))
droppedBlocks.clear()
- memoryStore.releasePendingUnrollMemoryForThisThread()
+ memoryStore.releasePendingUnrollMemoryForThisTask()
// Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 =
// 4800 bytes, there is still not enough room to unroll this block. This returns an iterator.
@@ -1085,7 +1085,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY)
unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks)
verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false)
- assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
assert(droppedBlocks.size === 1)
assert(droppedBlocks.head._1 === TestBlockId("someBlock2"))
droppedBlocks.clear()
@@ -1099,7 +1099,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
val bigList = List.fill(40)(new Array[Byte](1000))
def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]]
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
// Unroll with plenty of space. This should succeed and cache both blocks.
val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
@@ -1110,7 +1110,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(result2.size > 0)
assert(result1.data.isLeft) // unroll did not drop this block to disk
assert(result2.data.isLeft)
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
// Re-put these two blocks so block manager knows about them too. Otherwise, block manager
// would not know how to drop them from memory later.
@@ -1126,7 +1126,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!memoryStore.contains("b1"))
assert(memoryStore.contains("b2"))
assert(memoryStore.contains("b3"))
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
memoryStore.remove("b3")
store.putIterator("b3", smallIterator, memOnly)
@@ -1138,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!memoryStore.contains("b2"))
assert(memoryStore.contains("b3"))
assert(!memoryStore.contains("b4"))
- assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
}
/**
@@ -1153,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
val bigList = List.fill(40)(new Array[Byte](1000))
def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]]
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
store.putIterator("b1", smallIterator, memAndDisk)
store.putIterator("b2", smallIterator, memAndDisk)
@@ -1170,7 +1170,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!diskStore.contains("b3"))
memoryStore.remove("b3")
store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY)
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
// Unroll huge block with not enough space. This should fail and drop the new block to disk
// directly in addition to kicking out b2 in the process. Memory store should contain only
@@ -1186,7 +1186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(diskStore.contains("b2"))
assert(!diskStore.contains("b3"))
assert(diskStore.contains("b4"))
- assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
}
test("multiple unrolls by the same thread") {
@@ -1195,32 +1195,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
val memoryStore = store.memoryStore
val smallList = List.fill(40)(new Array[Byte](100))
def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
// All unroll memory used is released because unrollSafely returned an array
memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true)
- assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(memoryStore.currentUnrollMemoryForThisTask === 0)
// Unroll memory is not released because unrollSafely returned an iterator
// that still depends on the underlying vector used in the process
memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true)
- val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread
+ val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask
assert(unrollMemoryAfterB3 > 0)
// The unroll memory owned by this thread builds on top of its value after the previous unrolls
memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true)
- val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread
+ val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask
assert(unrollMemoryAfterB4 > unrollMemoryAfterB3)
// ... but only to a certain extent (until we run out of free space to grant new unroll memory)
memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true)
- val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread
+ val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask
memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true)
- val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread
+ val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask
memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true)
- val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread
+ val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask
assert(unrollMemoryAfterB5 === unrollMemoryAfterB4)
assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
assert(unrollMemoryAfterB7 === unrollMemoryAfterB4)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dc03e374b51db..26a2e96edaaa2 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -17,22 +17,29 @@
package org.apache.spark.util.collection.unsafe.sort
+import com.google.common.primitives.UnsignedBytes
import org.scalatest.prop.PropertyChecks
-
import org.apache.spark.SparkFunSuite
+import org.apache.spark.unsafe.types.UTF8String
class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
test("String prefix comparator") {
def testPrefixComparison(s1: String, s2: String): Unit = {
- val s1Prefix = PrefixComparators.STRING.computePrefix(s1)
- val s2Prefix = PrefixComparators.STRING.computePrefix(s2)
+ val utf8string1 = UTF8String.fromString(s1)
+ val utf8string2 = UTF8String.fromString(s2)
+ val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1)
+ val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2)
val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
+
+ val cmp = UnsignedBytes.lexicographicalComparator().compare(
+ utf8string1.getBytes.take(8), utf8string2.getBytes.take(8))
+
assert(
- (prefixComparisonResult == 0) ||
- (prefixComparisonResult < 0 && s1 < s2) ||
- (prefixComparisonResult > 0 && s1 > s2))
+ (prefixComparisonResult == 0 && cmp == 0) ||
+ (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) ||
+ (prefixComparisonResult > 0 && s1.compareTo(s2) > 0))
}
// scalastyle:off
@@ -48,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
}
- test("float prefix comparator handles NaN properly") {
- val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
- val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
- assert(nan1.isNaN)
- assert(nan2.isNaN)
- val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
- val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
- assert(nan1Prefix === nan2Prefix)
- val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
- assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
- }
-
test("double prefix comparator handles NaNs properly") {
val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
assert(nan1.isNaN)
assert(nan2.isNaN)
- val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
- val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+ val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
+ val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2)
assert(nan1Prefix === nan2Prefix)
- val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+ val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
}
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 1f0d218514f92..29420da9aa956 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -85,6 +85,13 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe
return [f for f in raw_output.split('\n') if f]
+def setup_test_environ(environ):
+ print("[info] Setup the following environment variables for tests: ")
+ for (k, v) in environ.items():
+ print("%s=%s" % (k, v))
+ os.environ[k] = v
+
+
def determine_modules_to_test(changed_modules):
"""
Given a set of modules that have changed, compute the transitive closure of those modules'
@@ -455,6 +462,15 @@ def main():
print("[info] Found the following changed modules:",
", ".join(x.name for x in changed_modules))
+ # setup environment variables
+ # note - the 'root' module doesn't collect environment variables for all modules, because environment
+ # variables should only be set when a module has actually changed, even when running the 'root' module.
+ # So we use changed_modules rather than test_modules here.
+ test_environ = {}
+ for m in changed_modules:
+ test_environ.update(m.environ)
+ setup_test_environ(test_environ)
+
test_modules = determine_modules_to_test(changed_modules)
# license checks
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 3073d489bad4a..44600cb9523c1 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -29,7 +29,7 @@ class Module(object):
changed.
"""
- def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
+ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={},
sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(),
should_run_r_tests=False):
"""
@@ -43,6 +43,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
filename strings.
:param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in
order to build and test this module (e.g. '-PprofileName').
+ :param environ: A dict of environment variables that should be set when files in this
+ module are changed.
:param sbt_test_goals: A set of SBT test goals for testing this module.
:param python_test_goals: A set of Python test goals for testing this module.
:param blacklisted_python_implementations: A set of Python implementations that are not
@@ -55,6 +57,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
self.source_file_prefixes = source_file_regexes
self.sbt_test_goals = sbt_test_goals
self.build_profile_flags = build_profile_flags
+ self.environ = environ
self.python_test_goals = python_test_goals
self.blacklisted_python_implementations = blacklisted_python_implementations
self.should_run_r_tests = should_run_r_tests
@@ -126,15 +129,22 @@ def contains_file(self, filename):
)
+# Don't set the dependencies because changes in other modules should not trigger Kinesis tests.
+# Kinesis tests depend on the external Amazon Kinesis service. We should run these tests only when
+# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't
+# fail other PRs.
streaming_kinesis_asl = Module(
name="kinesis-asl",
- dependencies=[streaming],
+ dependencies=[],
source_file_regexes=[
"extras/kinesis-asl/",
],
build_profile_flags=[
"-Pkinesis-asl",
],
+ environ={
+ "ENABLE_KINESIS_TESTS": "1"
+ },
sbt_test_goals=[
"kinesis-asl/test",
]
@@ -313,7 +323,7 @@ def contains_file(self, filename):
"pyspark.mllib.evaluation",
"pyspark.mllib.feature",
"pyspark.mllib.fpm",
- "pyspark.mllib.linalg",
+ "pyspark.mllib.linalg.__init__",
"pyspark.mllib.random",
"pyspark.mllib.recommendation",
"pyspark.mllib.regression",
diff --git a/docs/configuration.md b/docs/configuration.md
index 200f3cd212e46..fd236137cb96e 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -203,7 +203,7 @@ Apart from these, the following properties are also available, and may be useful
spark.driver.extraClassPath
(none)
- Extra classpath entries to append to the classpath of the driver.
+ Extra classpath entries to prepend to the classpath of the driver.
Note: In client mode, this config must not be set through the SparkConf
directly in your application, because the driver JVM has already started at that point.
@@ -250,7 +250,7 @@ Apart from these, the following properties are also available, and may be useful
spark.executor.extraClassPath
(none)
- Extra classpath entries to append to the classpath of executors. This exists primarily for
+ Extra classpath entries to prepend to the classpath of executors. This exists primarily for
backwards-compatibility with older versions of Spark. Users typically should not need to set
this option.
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
new file mode 100644
index 0000000000000..4ca0bb06b26a6
--- /dev/null
+++ b/docs/mllib-evaluation-metrics.md
@@ -0,0 +1,1497 @@
+---
+layout: global
+title: Evaluation Metrics - MLlib
+displayTitle: MLlib - Evaluation Metrics
+---
+
+* Table of contents
+{:toc}
+
+Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions
+on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance
+of the model against some criteria, which depend on the application and its requirements. Spark's MLlib also provides a
+suite of metrics for the purpose of evaluating the performance of machine learning models.
+
+Specific machine learning algorithms fall under broader types of machine learning applications like classification,
+regression, clustering, etc. Each of these types has well-established metrics for performance evaluation, and the
+metrics currently available in Spark's MLlib are detailed in this section.
+
+## Classification model evaluation
+
+While there are many different types of classification algorithms, the evaluation of classification models shares
+similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification),
+there exists a true output and a model-generated predicted output for each data point. For this reason, the results for
+each data point can be assigned to one of four categories:
+
+* True Positive (TP) - label is positive and prediction is also positive
+* True Negative (TN) - label is negative and prediction is also negative
+* False Positive (FP) - label is negative but prediction is positive
+* False Negative (FN) - label is positive but prediction is negative
+
+These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering
+classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is generally not a good
+metric, because a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from
+a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier
+that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like
+[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into
+account the *type* of error. In most applications there is some desired balance between precision and recall, which can
+be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score).
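+
+To make the point concrete, the short sketch below computes these quantities from hypothetical confusion-matrix
+counts for the fraud example above; the counts are invented purely for illustration and do not come from any
+MLlib API.
+
+{% highlight scala %}
+// Hypothetical counts for a classifier on an unbalanced fraud dataset (1000 points, 50 actual frauds)
+val tp = 30.0  // frauds correctly flagged
+val fp = 20.0  // non-frauds incorrectly flagged
+val fn = 20.0  // frauds that were missed
+val tn = 930.0 // non-frauds correctly passed
+
+val accuracy = (tp + tn) / (tp + tn + fp + fn) // 0.96, even though 40% of the frauds were missed
+val precision = tp / (tp + fp)                 // 0.6
+val recall = tp / (tp + fn)                    // 0.6
+val beta = 1.0
+val fMeasure = (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
+
+println(s"accuracy=$accuracy precision=$precision recall=$recall F($beta)=$fMeasure")
+{% endhighlight %}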
+
+### Binary classification
+
+[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given
+dataset into one of two possible groups (e.g. fraud or not fraud); binary classification is a special case of multiclass
+classification. Most binary classification metrics can be generalized to multiclass classification metrics.
+
+#### Threshold tuning
+
+It is important to understand that many classification models actually output a "score" (often a probability) for
+each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for
+each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be cases where
+the model needs to be tuned so that it only predicts a class when the probability is very high (e.g. only block a
+credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold*
+which determines what the predicted class will be, based on the probabilities that the model outputs.
+
+Tuning the prediction threshold will change the precision and recall of the model and is an important part of model
+optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold, it is
+common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision,
+recall) points for different threshold values, while a
+[receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve
+plots (recall, false positive rate) points.
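+
+As a minimal sketch of what threshold tuning can look like in practice (assuming a `model` trained with
+`LogisticRegressionWithLBFGS` and a `metrics` object built as in the full example below, and treating the target
+precision of 0.9 as an application-specific choice), one can pick the lowest threshold that reaches the desired
+precision and apply it to the model:
+
+{% highlight scala %}
+// precisionByThreshold yields an RDD of (threshold, precision) pairs
+val preciseEnough = metrics.precisionByThreshold.filter { case (_, p) => p >= 0.9 }
+
+if (preciseEnough.isEmpty()) {
+  println("No threshold reaches 0.9 precision")
+} else {
+  val chosenThreshold = preciseEnough.keys.min()
+  // predict() will now return 1.0 only when the predicted probability is at least chosenThreshold
+  model.setThreshold(chosenThreshold)
+}
+{% endhighlight %}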
+
+**Available metrics**
+
+| Metric | Definition |
+| ------ | ---------- |
+| Precision (Positive Predictive Value) | $PPV=\frac{TP}{TP + FP}$ |
+| Recall (True Positive Rate) | $TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$ |
+| F-measure | $F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}{\beta^2 \cdot PPV + TPR}\right)$ |
+| Receiver Operating Characteristic (ROC) | $FPR(T)=\int^\infty_{T} P_0(T)\,dT \quad TPR(T)=\int^\infty_{T} P_1(T)\,dT$ |
+| Area Under ROC Curve | $AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$ |
+| Area Under Precision-Recall Curve | $AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$ |
+
+**Examples**
+
+
+The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the
+data, and evaluate the performance of the algorithm by several binary evaluation metrics.
+
+
+
+{% highlight scala %}
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data in LIBSVM format
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
+
+// Split data into training (60%) and test (40%)
+val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+training.cache()
+
+// Run training algorithm to build the model
+val model = new LogisticRegressionWithLBFGS()
+ .setNumClasses(2)
+ .run(training)
+
+// Clear the prediction threshold so the model will return probabilities
+model.clearThreshold
+
+// Compute raw scores on the test set
+val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+ val prediction = model.predict(features)
+ (prediction, label)
+}
+
+// Instantiate metrics object
+val metrics = new BinaryClassificationMetrics(predictionAndLabels)
+
+// Precision by threshold
+val precision = metrics.precisionByThreshold
+precision.foreach { case (t, p) =>
+ println(s"Threshold: $t, Precision: $p")
+}
+
+// Recall by threshold
+val recall = metrics.recallByThreshold
+recall.foreach { case (t, r) =>
+ println(s"Threshold: $t, Recall: $r")
+}
+
+// Precision-Recall Curve
+val PRC = metrics.pr
+
+// F-measure
+val f1Score = metrics.fMeasureByThreshold
+f1Score.foreach { case (t, f) =>
+ println(s"Threshold: $t, F-score: $f, Beta = 1")
+}
+
+val beta = 0.5
+val fScore = metrics.fMeasureByThreshold(beta)
+fScore.foreach { case (t, f) =>
+ println(s"Threshold: $t, F-score: $f, Beta = 0.5")
+}
+
+// AUPRC
+val auPRC = metrics.areaUnderPR
+println("Area under precision-recall curve = " + auPRC)
+
+// Compute thresholds used in ROC and PR curves
+val thresholds = precision.map(_._1)
+
+// ROC Curve
+val roc = metrics.roc
+
+// AUROC
+val auROC = metrics.areaUnderROC
+println("Area under ROC = " + auROC)
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class BinaryClassification {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics");
+ SparkContext sc = new SparkContext(conf);
+ String path = "data/mllib/sample_binary_classification_data.txt";
+ JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
+
+ // Split initial RDD into two... [60% training data, 40% testing data].
+ JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
+ JavaRDD<LabeledPoint> training = splits[0].cache();
+ JavaRDD<LabeledPoint> test = splits[1];
+
+ // Run training algorithm to build the model.
+ final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+ .setNumClasses(2)
+ .run(training.rdd());
+
+ // Clear the prediction threshold so the model will return probabilities
+ model.clearThreshold();
+
+ // Compute raw scores on the test set.
+ JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
+ new Function<LabeledPoint, Tuple2<Object, Object>>() {
+ public Tuple2<Object, Object> call(LabeledPoint p) {
+ Double prediction = model.predict(p.features());
+ return new Tuple2<Object, Object>(prediction, p.label());
+ }
+ }
+ );
+
+ // Get evaluation metrics.
+ BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());
+
+ // Precision by threshold
+ JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
+ System.out.println("Precision by threshold: " + precision.toArray());
+
+ // Recall by threshold
+ JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
+ System.out.println("Recall by threshold: " + recall.toArray());
+
+ // F Score by threshold
+ JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
+ System.out.println("F1 Score by threshold: " + f1Score.toArray());
+
+ JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
+ System.out.println("F2 Score by threshold: " + f2Score.toArray());
+
+ // Precision-recall curve
+ JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
+ System.out.println("Precision-recall curve: " + prc.toArray());
+
+ // Thresholds
+ JavaRDD<Double> thresholds = precision.map(
+ new Function<Tuple2<Object, Object>, Double>() {
+ public Double call (Tuple2<Object, Object> t) {
+ return new Double(t._1().toString());
+ }
+ }
+ );
+
+ // ROC Curve
+ JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
+ System.out.println("ROC curve: " + roc.toArray());
+
+ // AUPRC
+ System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());
+
+ // AUROC
+ System.out.println("Area under ROC = " + metrics.areaUnderROC());
+
+ // Save and load model
+ model.save(sc, "myModelPath");
+ LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
+ }
+}
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight python %}
+from pyspark.mllib.classification import LogisticRegressionWithLBFGS
+from pyspark.mllib.evaluation import BinaryClassificationMetrics
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import MLUtils
+
+# Several of the methods available in Scala are currently missing from PySpark
+
+# Load training data in LIBSVM format
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
+
+# Split data into training (60%) and test (40%)
+training, test = data.randomSplit([0.6, 0.4], seed = 11L)
+training.cache()
+
+# Run training algorithm to build the model
+model = LogisticRegressionWithLBFGS.train(training)
+
+# Compute raw scores on the test set
+predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label))
+
+# Instantiate metrics object
+metrics = BinaryClassificationMetrics(predictionAndLabels)
+
+# Area under precision-recall curve
+print "Area under PR = %s" % metrics.areaUnderPR
+
+# Area under ROC curve
+print "Area under ROC = %s" % metrics.areaUnderROC
+
+{% endhighlight %}
+
+
+
+
+
+### Multiclass classification
+
+[Multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) describes a classification
+problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary
+classification problem). For example, classifying handwriting samples as the digits 0 through 9 has 10 possible classes.
+
+For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still
+be positive or negative, but they must be considered in the context of a particular class. Each label and prediction
+takes on the value of one of the multiple classes and so is said to be positive for its particular class and negative
+for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative
+occurs when neither the prediction nor the label takes on the value of a given class. By this convention, there can be
+multiple true negatives for a given data sample. The extension of false negatives and false positives from the former
+definitions of positive and negative labels is straightforward.
+
+#### Label based metrics
+
+As opposed to binary classification, where there are only two possible labels, multiclass classification problems have
+many possible labels, so the concept of label-based metrics is introduced. Overall precision measures precision across all
+labels - the number of times any class was predicted correctly (true positives) normalized by the number of data
+points. Precision by label considers only one class, and measures the number of times a specific label was predicted
+correctly, normalized by the number of times that label appears in the output.
+
+**Available metrics**
+
+Define the class, or label, set as
+
+$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$
+
+The true output vector $\mathbf{y}$ consists of $N$ elements
+
+$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$
+
+A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements
+
+$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$
+
+For this section, a modified delta function $\hat{\delta}(x)$ will prove useful
+
+$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+
+| Metric | Definition |
+| ------ | ---------- |
+| Confusion Matrix | $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\ \left( \begin{array}{ccc} \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots & \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) \\ \vdots & \ddots & \vdots \\ \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots & \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) \end{array} \right)$ |
+| Overall Precision | $PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$ |
+| Overall Recall | $TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$ |
+| Overall F1-measure | $F1 = 2 \cdot \left(\frac{PPV \cdot TPR}{PPV + TPR}\right)$ |
+| Precision by label | $PPV(\ell) = \frac{TP}{TP + FP} = \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$ |
+| Recall by label | $TPR(\ell)=\frac{TP}{P} = \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}{\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$ |
+| F-measure by label | $F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}{\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$ |
+| Weighted precision | $PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$ |
+| Weighted recall | $TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$ |
+| Weighted F-measure | $F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$ |
+
+**Examples**
+
+
+The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on
+the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics.
+
+
+
+{% highlight scala %}
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data in LIBSVM format
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")
+
+// Split data into training (60%) and test (40%)
+val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+training.cache()
+
+// Run training algorithm to build the model
+val model = new LogisticRegressionWithLBFGS()
+ .setNumClasses(3)
+ .run(training)
+
+// Compute raw scores on the test set
+val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+ val prediction = model.predict(features)
+ (prediction, label)
+}
+
+// Instantiate metrics object
+val metrics = new MulticlassMetrics(predictionAndLabels)
+
+// Confusion matrix
+println("Confusion matrix:")
+println(metrics.confusionMatrix)
+
+// Overall Statistics
+val precision = metrics.precision
+val recall = metrics.recall // same as true positive rate
+val f1Score = metrics.fMeasure
+println("Summary Statistics")
+println(s"Precision = $precision")
+println(s"Recall = $recall")
+println(s"F1 Score = $f1Score")
+
+// Precision by label
+val labels = metrics.labels
+labels.foreach { l =>
+ println(s"Precision($l) = " + metrics.precision(l))
+}
+
+// Recall by label
+labels.foreach { l =>
+ println(s"Recall($l) = " + metrics.recall(l))
+}
+
+// False positive rate by label
+labels.foreach { l =>
+ println(s"FPR($l) = " + metrics.falsePositiveRate(l))
+}
+
+// F-measure by label
+labels.foreach { l =>
+ println(s"F1-Score($l) = " + metrics.fMeasure(l))
+}
+
+// Weighted stats
+println(s"Weighted precision: ${metrics.weightedPrecision}")
+println(s"Weighted recall: ${metrics.weightedRecall}")
+println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
+println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.MulticlassMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class MulticlassClassification {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics");
+ SparkContext sc = new SparkContext(conf);
+ String path = "data/mllib/sample_multiclass_classification_data.txt";
+ JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
+
+ // Split initial RDD into two... [60% training data, 40% testing data].
+ JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
+ JavaRDD<LabeledPoint> training = splits[0].cache();
+ JavaRDD<LabeledPoint> test = splits[1];
+
+ // Run training algorithm to build the model.
+ final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+ .setNumClasses(3)
+ .run(training.rdd());
+
+ // Compute raw scores on the test set.
+ JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
+ new Function<LabeledPoint, Tuple2<Object, Object>>() {
+ public Tuple2<Object, Object> call(LabeledPoint p) {
+ Double prediction = model.predict(p.features());
+ return new Tuple2<Object, Object>(prediction, p.label());
+ }
+ }
+ );
+
+ // Get evaluation metrics.
+ MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
+
+ // Confusion matrix
+ Matrix confusion = metrics.confusionMatrix();
+ System.out.println("Confusion matrix: \n" + confusion);
+
+ // Overall statistics
+ System.out.println("Precision = " + metrics.precision());
+ System.out.println("Recall = " + metrics.recall());
+ System.out.println("F1 Score = " + metrics.fMeasure());
+
+ // Stats by labels
+ for (int i = 0; i < metrics.labels().length; i++) {
+ System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+ System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+ System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i]));
+ }
+
+ // Weighted stats
+ System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
+ System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
+ System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
+ System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());
+
+ // Save and load model
+ model.save(sc, "myModelPath");
+ LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
+ }
+}
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight python %}
+from pyspark.mllib.classification import LogisticRegressionWithLBFGS
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.evaluation import MulticlassMetrics
+
+# Load training data in LIBSVM format
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")
+
+# Split data into training (60%) and test (40%)
+training, test = data.randomSplit([0.6, 0.4], seed = 11L)
+training.cache()
+
+# Run training algorithm to build the model
+model = LogisticRegressionWithLBFGS.train(training, numClasses=3)
+
+# Compute raw scores on the test set
+predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label))
+
+# Instantiate metrics object
+metrics = MulticlassMetrics(predictionAndLabels)
+
+# Overall statistics
+precision = metrics.precision()
+recall = metrics.recall()
+f1Score = metrics.fMeasure()
+print "Summary Stats"
+print "Precision = %s" % precision
+print "Recall = %s" % recall
+print "F1 Score = %s" % f1Score
+
+# Statistics by class
+labels = data.map(lambda lp: lp.label).distinct().collect()
+for label in sorted(labels):
+ print "Class %s precision = %s" % (label, metrics.precision(label))
+ print "Class %s recall = %s" % (label, metrics.recall(label))
+ print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))
+
+# Weighted stats
+print "Weighted recall = %s" % metrics.weightedRecall
+print "Weighted precision = %s" % metrics.weightedPrecision
+print "Weighted F(1) Score = %s" % metrics.weightedFMeasure()
+print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)
+print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate
+{% endhighlight %}
+
+
+
+
+### Multilabel classification
+
+A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping
+each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not
+mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both
+science and politics.
+
+Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather
+than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to
+operations on sets. For example, a true positive for a given class now occurs when that class exists in the predicted
+set and it exists in the true label set, for a specific data point.
+
+**Available metrics**
+
+Here we define a set $D$ of $N$ documents
+
+$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$
+to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that
+correspond to document $d_i$.
+
+The set of all unique labels is given by
+
+$$L = \bigcup_{k=0}^{N-1} L_k$$
+
+The following definition of the indicator function $I_A(x)$ on a set $A$ will be necessary
+
+$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+
+| Metric | Definition |
+| ------ | ---------- |
+| Precision | $\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left\vert P_i \cap L_i\right\vert}{\left\vert P_i\right\vert}$ |
+| Recall | $\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left\vert L_i \cap P_i\right\vert}{\left\vert L_i\right\vert}$ |
+| Accuracy | $\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left\vert L_i \cap P_i \right\vert}{\left\vert L_i\right\vert + \left\vert P_i\right\vert - \left\vert L_i \cap P_i \right\vert}$ |
+| Precision by label | $PPV(\ell)=\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}{\sum_{i=0}^{N-1} I_{P_i}(\ell)}$ |
+| Recall by label | $TPR(\ell)=\frac{TP}{P}=\frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}{\sum_{i=0}^{N-1} I_{L_i}(\ell)}$ |
+| F1-measure by label | $F1(\ell) = 2 \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}{PPV(\ell) + TPR(\ell)}\right)$ |
+| Hamming Loss | $\frac{1}{N \cdot \left\vert L\right\vert} \sum_{i=0}^{N - 1} \left\vert L_i\right\vert + \left\vert P_i\right\vert - 2\left\vert L_i \cap P_i\right\vert$ |
+| Subset Accuracy | $\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$ |
+| F1 Measure | $\frac{1}{N} \sum_{i=0}^{N-1} \frac{2\left\vert P_i \cap L_i\right\vert}{\left\vert P_i\right\vert + \left\vert L_i\right\vert}$ |
+| Micro precision | $\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left\vert P_i \cap L_i\right\vert}{\sum_{i=0}^{N-1} \left\vert P_i \cap L_i\right\vert + \sum_{i=0}^{N-1} \left\vert P_i - L_i\right\vert}$ |
+| Micro recall | $\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left\vert P_i \cap L_i\right\vert}{\sum_{i=0}^{N-1} \left\vert P_i \cap L_i\right\vert + \sum_{i=0}^{N-1} \left\vert L_i - P_i\right\vert}$ |
+| Micro F1 Measure | $2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left\vert P_i \cap L_i\right\vert}{2 \cdot \sum_{i=0}^{N-1} \left\vert P_i \cap L_i\right\vert + \sum_{i=0}^{N-1} \left\vert L_i - P_i\right\vert + \sum_{i=0}^{N-1} \left\vert P_i - L_i\right\vert}$ |
+
+**Examples**
+
+The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples
+use the fabricated prediction and label data for multilabel classification shown below.
+
+Document predictions:
+
+* doc 0 - predict 0, 1 - class 0, 2
+* doc 1 - predict 0, 2 - class 0, 1
+* doc 2 - predict none - class 0
+* doc 3 - predict 2 - class 2
+* doc 4 - predict 2, 0 - class 2, 0
+* doc 5 - predict 0, 1, 2 - class 0, 1
+* doc 6 - predict 1 - class 1, 2
+
+Predicted classes:
+
+* class 0 - doc 0, 1, 4, 5 (total 4)
+* class 1 - doc 0, 5, 6 (total 3)
+* class 2 - doc 1, 3, 4, 5 (total 4)
+
+True classes:
+
+* class 0 - doc 0, 1, 2, 4, 5 (total 5)
+* class 1 - doc 1, 5, 6 (total 3)
+* class 2 - doc 0, 3, 4, 6 (total 4)
+
+
+
+
+
+{% highlight scala %}
+import org.apache.spark.mllib.evaluation.MultilabelMetrics
+import org.apache.spark.rdd.RDD
+
+val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
+ Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
+ (Array(0.0, 2.0), Array(0.0, 1.0)),
+ (Array(), Array(0.0)),
+ (Array(2.0), Array(2.0)),
+ (Array(2.0, 0.0), Array(2.0, 0.0)),
+ (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
+ (Array(1.0), Array(1.0, 2.0))), 2)
+
+// Instantiate metrics object
+val metrics = new MultilabelMetrics(scoreAndLabels)
+
+// Summary stats
+println(s"Recall = ${metrics.recall}")
+println(s"Precision = ${metrics.precision}")
+println(s"F1 measure = ${metrics.f1Measure}")
+println(s"Accuracy = ${metrics.accuracy}")
+
+// Individual label stats
+metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}"))
+metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}"))
+metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}"))
+
+// Micro stats
+println(s"Micro recall = ${metrics.microRecall}")
+println(s"Micro precision = ${metrics.microPrecision}")
+println(s"Micro F1 measure = ${metrics.microF1Measure}")
+
+// Hamming loss
+println(s"Hamming loss = ${metrics.hammingLoss}")
+
+// Subset accuracy
+println(s"Subset accuracy = ${metrics.subsetAccuracy}")
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.evaluation.MultilabelMetrics;
+import org.apache.spark.SparkConf;
+import java.util.Arrays;
+import java.util.List;
+
+public class MultilabelClassification {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics");
+ JavaSparkContext sc = new JavaSparkContext(conf);
+
+ List<Tuple2<double[], double[]>> data = Arrays.asList(
+ new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
+ new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
+ new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
+ new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
+ new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
+ new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
+ new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
+ );
+ JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
+
+ // Instantiate metrics object
+ MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());
+
+ // Summary stats
+ System.out.format("Recall = %f\n", metrics.recall());
+ System.out.format("Precision = %f\n", metrics.precision());
+ System.out.format("F1 measure = %f\n", metrics.f1Measure());
+ System.out.format("Accuracy = %f\n", metrics.accuracy());
+
+ // Stats by labels
+ for (int i = 0; i < metrics.labels().length; i++) {
+ System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+ System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+ System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i]));
+ }
+
+ // Micro stats
+ System.out.format("Micro recall = %f\n", metrics.microRecall());
+ System.out.format("Micro precision = %f\n", metrics.microPrecision());
+ System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());
+
+ // Hamming loss
+ System.out.format("Hamming loss = %f\n", metrics.hammingLoss());
+
+ // Subset accuracy
+ System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
+
+ }
+}
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight python %}
+from pyspark.mllib.evaluation import MultilabelMetrics
+
+scoreAndLabels = sc.parallelize([
+ ([0.0, 1.0], [0.0, 2.0]),
+ ([0.0, 2.0], [0.0, 1.0]),
+ ([], [0.0]),
+ ([2.0], [2.0]),
+ ([2.0, 0.0], [2.0, 0.0]),
+ ([0.0, 1.0, 2.0], [0.0, 1.0]),
+ ([1.0], [1.0, 2.0])])
+
+# Instantiate metrics object
+metrics = MultilabelMetrics(scoreAndLabels)
+
+# Summary stats
+print "Recall = %s" % metrics.recall()
+print "Precision = %s" % metrics.precision()
+print "F1 measure = %s" % metrics.f1Measure()
+print "Accuracy = %s" % metrics.accuracy
+
+# Individual label stats
+labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect()
+for label in labels:
+ print "Class %s precision = %s" % (label, metrics.precision(label))
+ print "Class %s recall = %s" % (label, metrics.recall(label))
+ print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))
+
+# Micro stats
+print "Micro precision = %s" % metrics.microPrecision
+print "Micro recall = %s" % metrics.microRecall
+print "Micro F1 measure = %s" % metrics.microF1Measure
+
+# Hamming loss
+print "Hamming loss = %s" % metrics.hammingLoss
+
+# Subset accuracy
+print "Subset accuracy = %s" % metrics.subsetAccuracy
+
+{% endhighlight %}
+
+
+
+
+### Ranking systems
+
+The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system))
+is to return to the user a set of relevant items or documents based on some training data. The definition of relevance
+may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these
+rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth
+set of relevant documents, while other metrics may incorporate numerical ratings explicitly.
+
+**Available metrics**
+
+A ranking system usually deals with a set of $M$ users
+
+$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$
+
+Each user ($u_i$) having a set of $N$ ground truth relevant documents
+
+$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+And a list of $Q$ recommended documents, in order of decreasing relevance
+
+$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$
+
+The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the
+sets and the effectiveness of the algorithms can be measured using the metrics listed below.
+
+It is necessary to define a function which, provided a recommended document and a set of ground truth relevant
+documents, returns a relevance score for the recommended document.
+
+$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+
+
+**Precision at k**
+
+$p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} rel_{D_i}(R_i(j))}$
+
+Precision at k measures how many of the first k recommended documents are in the set of true relevant documents,
+averaged across all users. The order of the recommendations within the top k is not taken into account.
+
+**Mean Average Precision (MAP)**
+
+$MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$
+
+MAP measures how many of the recommended documents are in the set of true relevant documents, where the order of
+the recommendations is taken into account (i.e. the penalty is higher when a relevant document is ranked low).
+
+**Normalized Discounted Cumulative Gain (NDCG at k)**
+
+$NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1}
+\frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\
+\text{Where} \\
+\hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\
+\hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$
+
+NDCG at k measures how many of the first k recommended documents are in the set of true relevant documents,
+averaged across all users. In contrast to precision at k, this metric takes the order of the recommendations
+into account (documents are assumed to be listed in order of decreasing relevance).
+
+
+
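+As a small illustration of these definitions for a single user ($M = 1$) with hypothetical documents: if the
+relevant set is $D = \{d_1, d_2, d_3\}$ and the ranked recommendations are $R = [d_1, d_4, d_2]$, then
+$rel_D(d_1) = 1$, $rel_D(d_4) = 0$ and $rel_D(d_2) = 1$, so $p(1) = 1$, $p(2) = \frac{1 + 0}{2} = 0.5$ and
+$p(3) = \frac{1 + 0 + 1}{3} \approx 0.67$.
+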
+
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation
+model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the
+methodology is provided below.
+
+MovieLens ratings are on a scale of 1-5:
+
+ * 5: Must see
+ * 4: Will enjoy
+ * 3: It's okay
+ * 2: Fairly bad
+ * 1: Awful
+
+So we should not recommend a movie if the predicted rating is less than 3.
+To map ratings to confidence scores, we use:
+
+ * 5 -> 2.5
+ * 4 -> 1.5
+ * 3 -> 0.5
+ * 2 -> -0.5
+ * 1 -> -1.5.
+
+This mapping simply shifts each rating r to r - 2.5, so unobserved entries sit roughly between "It's okay" and
+"Fairly bad". In this expanded world of non-positive weights, a score of 0 has the same semantics as never having
+interacted at all.
+
+
+
+
+
+{% highlight scala %}
+import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics}
+import org.apache.spark.mllib.recommendation.{ALS, Rating}
+
+// Read in the ratings data
+val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line =>
+ val fields = line.split("::")
+ Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5)
+}.cache()
+
+// Map ratings to 1 or 0, 1 indicating a movie that should be recommended
+val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache()
+
+// Summarize ratings
+val numRatings = ratings.count()
+val numUsers = ratings.map(_.user).distinct().count()
+val numMovies = ratings.map(_.product).distinct().count()
+println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
+
+// Build the model
+val numIterations = 10
+val rank = 10
+val lambda = 0.01
+val model = ALS.train(ratings, rank, numIterations, lambda)
+
+// Define a function to scale ratings from 0 to 1
+def scaledRating(r: Rating): Rating = {
+ val scaledRating = math.max(math.min(r.rating, 1.0), 0.0)
+ Rating(r.user, r.product, scaledRating)
+}
+
+// Get sorted top ten predictions for each user, then clamp the ratings to [0, 1]
+val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) =>
+ (user, recs.map(scaledRating))
+}
+
+// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document
+// Compare with top ten most relevant documents
+val userMovies = binarizedRatings.groupBy(_.user)
+val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) =>
+ (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray)
+}
+
+// Instantiate metrics object
+val metrics = new RankingMetrics(relevantDocuments)
+
+// Precision at K
+Array(1, 3, 5).foreach{ k =>
+ println(s"Precision at $k = ${metrics.precisionAt(k)}")
+}
+
+// Mean average precision
+println(s"Mean average precision = ${metrics.meanAveragePrecision}")
+
+// Normalized discounted cumulative gain
+Array(1, 3, 5).foreach{ k =>
+ println(s"NDCG at $k = ${metrics.ndcgAt(k)}")
+}
+
+// Get predictions for each data point
+val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating))
+val allRatings = ratings.map(r => ((r.user, r.product), r.rating))
+val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) =>
+ (predicted, actual)
+}
+
+// Get the RMSE using regression metrics
+val regressionMetrics = new RegressionMetrics(predictionsAndLabels)
+println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}")
+
+// R-squared
+println(s"R-squared = ${regressionMetrics.r2}")
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.Function;
+import java.util.*;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.mllib.evaluation.RankingMetrics;
+import org.apache.spark.mllib.recommendation.ALS;
+import org.apache.spark.mllib.recommendation.Rating;
+
+// Read in the ratings data
+public class Ranking {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("Ranking Metrics");
+ JavaSparkContext sc = new JavaSparkContext(conf);
+ String path = "data/mllib/sample_movielens_data.txt";
+ JavaRDD<String> data = sc.textFile(path);
+ JavaRDD<Rating> ratings = data.map(
+ new Function<String, Rating>() {
+ public Rating call(String line) {
+ String[] parts = line.split("::");
+ return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5);
+ }
+ }
+ );
+ ratings.cache();
+
+ // Train an ALS model
+ final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01);
+
+ // Get top 10 recommendations for every user and scale ratings from 0 to 1
+ JavaRDD<Tuple2<Object, Rating[]>> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
+ JavaRDD<Tuple2<Object, Rating[]>> userRecsScaled = userRecs.map(
+ new Function<Tuple2<Object, Rating[]>, Tuple2<Object, Rating[]>>() {
+ public Tuple2<Object, Rating[]> call(Tuple2<Object, Rating[]> t) {
+ Rating[] scaledRatings = new Rating[t._2().length];
+ for (int i = 0; i < scaledRatings.length; i++) {
+ double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
+ scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
+ }
+ return new Tuple2<Object, Rating[]>(t._1(), scaledRatings);
+ }
+ }
+ );
+ JavaPairRDD<Object, Rating[]> userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);
+
+ // Map ratings to 1 or 0, 1 indicating a movie that should be recommended
+ JavaRDD<Rating> binarizedRatings = ratings.map(
+ new Function<Rating, Rating>() {
+ public Rating call(Rating r) {
+ double binaryRating;
+ if (r.rating() > 0.0) {
+ binaryRating = 1.0;
+ }
+ else {
+ binaryRating = 0.0;
+ }
+ return new Rating(r.user(), r.product(), binaryRating);
+ }
+ }
+ );
+
+ // Group ratings by common user
+ JavaPairRDD<Object, Iterable<Rating>> userMovies = binarizedRatings.groupBy(
+ new Function<Rating, Object>() {
+ public Object call(Rating r) {
+ return r.user();
+ }
+ }
+ );
+
+ // Get true relevant documents from all user ratings
+ JavaPairRDD<Object, List<Integer>> userMoviesList = userMovies.mapValues(
+ new Function<Iterable<Rating>, List<Integer>>() {
+ public List<Integer> call(Iterable<Rating> docs) {
+ List<Integer> products = new ArrayList<Integer>();
+ for (Rating r : docs) {
+ if (r.rating() > 0.0) {
+ products.add(r.product());
+ }
+ }
+ return products;
+ }
+ }
+ );
+
+ // Extract the product id from each recommendation
+ JavaPairRDD<Object, List<Integer>> userRecommendedList = userRecommended.mapValues(
+ new Function<Rating[], List<Integer>>() {
+ public List<Integer> call(Rating[] docs) {
+ List<Integer> products = new ArrayList<Integer>();
+ for (Rating r : docs) {
+ products.add(r.product());
+ }
+ return products;
+ }
+ }
+ );
+ JavaRDD<Tuple2<List<Integer>, List<Integer>>> relevantDocs = userMoviesList.join(userRecommendedList).values();
+
+ // Instantiate the metrics object
+ RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);
+
+ // Precision and NDCG at k
+ Integer[] kVector = {1, 3, 5};
+ for (Integer k : kVector) {
+ System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
+ System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+ }
+
+ // Mean average precision
+ System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision());
+
+ // Evaluate the model using numerical ratings and regression metrics
+ JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
+ new Function<Rating, Tuple2<Object, Object>>() {
+ public Tuple2<Object, Object> call(Rating r) {
+ return new Tuple2<Object, Object>(r.user(), r.product());
+ }
+ }
+ );
+ JavaPairRDD<Tuple2<Integer, Integer>, Object> predictions = JavaPairRDD.fromJavaRDD(
+ model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
+ new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+ public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r){
+ return new Tuple2<Tuple2<Integer, Integer>, Object>(
+ new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+ }
+ }
+ ));
+ JavaRDD<Tuple2<Object, Object>> ratesAndPreds =
+ JavaPairRDD.fromJavaRDD(ratings.map(
+ new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+ public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r){
+ return new Tuple2<Tuple2<Integer, Integer>, Object>(
+ new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+ }
+ }
+ )).join(predictions).values();
+
+ // Create regression metrics object
+ RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
+
+ // Root mean squared error
+ System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError());
+
+ // R-squared
+ System.out.format("R-squared = %f\n", regressionMetrics.r2());
+ }
+}
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight python %}
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+
+# Read in the ratings data
+lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+def parseLine(line):
+ fields = line.split("::")
+ return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+ratings = lines.map(lambda r: parseLine(r))
+
+# Train a model to predict user-product ratings
+model = ALS.train(ratings, 10, 10, 0.01)
+
+# Get predicted ratings on all existing user-product pairs
+testData = ratings.map(lambda p: (p.user, p.product))
+predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+# Instantiate regression metrics to compare predicted and actual ratings
+metrics = RegressionMetrics(scoreAndLabels)
+
+# Root mean squared error
+print "RMSE = %s" % metrics.rootMeanSquaredError
+
+# R-squared
+print "R-squared = %s" % metrics.r2
+
+{% endhighlight %}
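+
+The Python snippet above evaluates the recommendations with regression metrics only, although `RankingMetrics` is
+imported. As a minimal sketch, using hypothetical per-user lists of recommended and ground-truth document ids, the
+ranking metrics can also be computed directly with `RankingMetrics`:
+
+{% highlight python %}
+from pyspark.mllib.evaluation import RankingMetrics
+
+# Hypothetical toy data: one user whose ground-truth relevant documents are
+# [1, 2, 5] and whose ranked recommendations are [1, 6, 2]
+predictionAndLabels = sc.parallelize([([1, 6, 2], [1, 2, 5])])
+rankingMetrics = RankingMetrics(predictionAndLabels)
+
+print "Precision at 2 = %s" % rankingMetrics.precisionAt(2)
+print "Mean average precision = %s" % rankingMetrics.meanAveragePrecision
+print "NDCG at 2 = %s" % rankingMetrics.ndcgAt(2)
+{% endhighlight %}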
+
+
+
+
+## Regression model evaluation
+
+[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output
+variable from a number of independent variables.
+
+**Available metrics**
+
+
+
+**Mean Squared Error (MSE)**
+
+$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$
+
+**Root Mean Squared Error (RMSE)**
+
+$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$
+
+**Mean Absolute Error (MAE)**
+
+$MAE = \frac{1}{N}\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$
+
+**Coefficient of Determination $(R^2)$**
+
+$R^2 = 1 - \frac{N \cdot MSE}{\text{VAR}(\mathbf{y}) \cdot (N-1)}=1-\frac{\sum_{i=0}^{N-1}
+(\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$
+
+**Explained Variance**
+
+$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$
+
+
+
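+For example, for labels $\mathbf{y} = (2.0, 3.0, 4.0)$ and predictions $\hat{\mathbf{y}} = (2.5, 3.0, 3.0)$ the
+errors are $(-0.5, 0.0, 1.0)$, so $MSE = \frac{0.25 + 0 + 1}{3} \approx 0.417$, $RMSE \approx 0.645$,
+$MAE = \frac{0.5 + 0 + 1}{3} = 0.5$ and, with $\bar{\mathbf{y}} = 3$, $R^2 = 1 - \frac{1.25}{2} = 0.375$.
+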
+
+**Examples**
+
+
+The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data,
+and evaluate the performance of the algorithm by several regression metrics.
+
+
+
+{% highlight scala %}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.regression.LinearRegressionModel
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.evaluation.RegressionMetrics
+import org.apache.spark.mllib.util.MLUtils
+
+// Load the data
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache()
+
+// Build the model
+val numIterations = 100
+val model = LinearRegressionWithSGD.train(data, numIterations)
+
+// Get predictions
+val valuesAndPreds = data.map{ point =>
+ val prediction = model.predict(point.features)
+ (prediction, point.label)
+}
+
+// Instantiate metrics object
+val metrics = new RegressionMetrics(valuesAndPreds)
+
+// Squared error
+println(s"MSE = ${metrics.meanSquaredError}")
+println(s"RMSE = ${metrics.rootMeanSquaredError}")
+
+// R-squared
+println(s"R-squared = ${metrics.r2}")
+
+// Mean absolute error
+println(s"MAE = ${metrics.meanAbsoluteError}")
+
+// Explained variance
+println(s"Explained variance = ${metrics.explainedVariance}")
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.regression.LinearRegressionModel;
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.SparkConf;
+
+public class LinearRegression {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
+ JavaSparkContext sc = new JavaSparkContext(conf);
+
+ // Load and parse the data
+ String path = "data/mllib/sample_linear_regression_data.txt";
+ JavaRDD<String> data = sc.textFile(path);
+ JavaRDD<LabeledPoint> parsedData = data.map(
+ new Function<String, LabeledPoint>() {
+ public LabeledPoint call(String line) {
+ String[] parts = line.split(" ");
+ double[] v = new double[parts.length - 1];
+ for (int i = 1; i < parts.length; i++)
+ v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
+ return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+ }
+ }
+ );
+ parsedData.cache();
+
+ // Building the model
+ int numIterations = 100;
+ final LinearRegressionModel model =
+ LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
+
+ // Evaluate model on training examples and compute training error
+ JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map(
+ new Function<LabeledPoint, Tuple2<Object, Object>>() {
+ public Tuple2<Object, Object> call(LabeledPoint point) {
+ double prediction = model.predict(point.features());
+ return new Tuple2<Object, Object>(prediction, point.label());
+ }
+ }
+ );
+
+ // Instantiate metrics object
+ RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());
+
+ // Squared error
+ System.out.format("MSE = %f\n", metrics.meanSquaredError());
+ System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());
+
+ // R-squared
+ System.out.format("R Squared = %f\n", metrics.r2());
+
+ // Mean absolute error
+ System.out.format("MAE = %f\n", metrics.meanAbsoluteError());
+
+ // Explained variance
+ System.out.format("Explained Variance = %f\n", metrics.explainedVariance());
+
+ // Save and load model
+ model.save(sc.sc(), "myModelPath");
+ LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
+ }
+}
+
+{% endhighlight %}
+
+
+
+
+
+{% highlight python %}
+from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD
+from pyspark.mllib.evaluation import RegressionMetrics
+from pyspark.mllib.linalg import DenseVector
+
+# Load and parse the data
+def parsePoint(line):
+ values = line.split()
+ return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]]))
+
+data = sc.textFile("data/mllib/sample_linear_regression_data.txt")
+parsedData = data.map(parsePoint)
+
+# Build the model
+model = LinearRegressionWithSGD.train(parsedData)
+
+# Get predictions
+valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label))
+
+# Instantiate metrics object
+metrics = RegressionMetrics(valuesAndPreds)
+
+# Squared Error
+print "MSE = %s" % metrics.meanSquaredError
+print "RMSE = %s" % metrics.rootMeanSquaredError
+
+# R-squared
+print "R-squared = %s" % metrics.r2
+
+# Mean absolute error
+print "MAE = %s" % metrics.meanAbsoluteError
+
+# Explained variance
+print "Explained variance = %s" % metrics.explainedVariance
+
+{% endhighlight %}
+
+
+
\ No newline at end of file
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index d2d1cc93fe006..eea864eacf7c4 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -48,6 +48,7 @@ This lists functionality included in `spark.mllib`, the main MLlib API.
* [Feature extraction and transformation](mllib-feature-extraction.html)
* [Frequent pattern mining](mllib-frequent-pattern-mining.html)
* FP-growth
+* [Evaluation Metrics](mllib-evaluation-metrics.html)
* [Optimization (developer)](mllib-optimization.html)
* stochastic gradient descent
* limited-memory BFGS (L-BFGS)
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 7c83d68e7993e..ccf922d9371fb 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -242,7 +242,7 @@ def parse_args():
help="Number of EBS volumes to attach to each node as /vol[x]. " +
"The volumes will be deleted when the instances terminate. " +
"Only possible on EBS-backed AMIs. " +
- "EBS volumes are only attached if --ebs-vol-size > 0." +
+ "EBS volumes are only attached if --ebs-vol-size > 0. " +
"Only support up to 8 EBS volumes.")
parser.add_option(
"--placement-group", type="string", default=None,
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index 0ff1b7ed0fd90..ca39358b75cb6 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -53,6 +53,8 @@ private class KinesisTestUtils(
@volatile
private var streamCreated = false
+
+ @volatile
private var _streamName: String = _
private lazy val kinesisClient = {
@@ -115,21 +117,9 @@ private class KinesisTestUtils(
shardIdToSeqNumbers.toMap
}
- def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = {
- try {
- val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe)
- val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription()
- Some(desc)
- } catch {
- case rnfe: ResourceNotFoundException =>
- None
- }
- }
-
def deleteStream(): Unit = {
try {
- if (describeStream().nonEmpty) {
- val deleteStreamRequest = new DeleteStreamRequest()
+ if (streamCreated) {
kinesisClient.deleteStream(streamName)
}
} catch {
@@ -149,6 +139,17 @@ private class KinesisTestUtils(
}
}
+ private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = {
+ try {
+ val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe)
+ val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription()
+ Some(desc)
+ } catch {
+ case rnfe: ResourceNotFoundException =>
+ None
+ }
+ }
+
private def findNonExistentStreamName(): String = {
var testStreamName: String = null
do {
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index b2e2a4246dbd5..e81fb11e5959f 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.streaming.kinesis
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
-import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkContext, SparkException}
class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll {
@@ -65,6 +65,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
}
override def afterAll(): Unit = {
+ if (testUtils != null) {
+ testUtils.deleteStream()
+ }
if (sc != null) {
sc.stop()
}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 4992b041765e9..b88c9c6478d56 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -59,7 +59,7 @@ class KinesisStreamSuite extends KinesisFunSuite
}
}
- ignore("KinesisUtils API") {
+ test("KinesisUtils API") {
ssc = new StreamingContext(sc, Seconds(1))
// Tests the API, does not actually test data receiving
val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream",
@@ -83,16 +83,16 @@ class KinesisStreamSuite extends KinesisFunSuite
* you must have AWS credentials available through the default AWS provider chain,
* and you have to set the system environment variable RUN_KINESIS_TESTS=1 .
*/
- ignore("basic operation") {
+ testIfEnabled("basic operation") {
val kinesisTestUtils = new KinesisTestUtils()
try {
kinesisTestUtils.createStream()
ssc = new StreamingContext(sc, Seconds(1))
- val aWSCredentials = KinesisTestUtils.getAWSCredentials()
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName,
kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST,
Seconds(10), StorageLevel.MEMORY_ONLY,
- aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey)
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index cfcf7244eaed5..2ca60d51f8331 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -127,28 +127,25 @@ object Pregel extends Logging {
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
- // Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
- val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
- // Update the graph with the new vertices.
+ // Receive the messages and update the vertices.
prevG = g
- g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
- g.cache()
+ g = g.joinVertices(messages)(vprog).cache()
val oldMessages = messages
- // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
- // get to send messages. We must cache messages so it can be materialized on the next line,
- // allowing us to uncache the previous iteration.
- messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
- // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
- // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
- // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
+ // Send new messages, skipping edges where neither side received a message. We must cache
+ // messages so it can be materialized on the next line, allowing us to uncache the previous
+ // iteration.
+ messages = g.mapReduceTriplets(
+ sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+ // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
+ // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
+ // and the vertices of g).
activeMessages = messages.count()
logInfo("Pregel finished iteration " + i)
// Unpersist the RDDs hidden by newly-materialized RDDs
oldMessages.unpersist(blocking = false)
- newVerts.unpersist(blocking = false)
prevG.unpersistVertices(blocking = false)
prevG.edges.unpersist(blocking = false)
// count the iteration
diff --git a/make-distribution.sh b/make-distribution.sh
index cac7032bb2e87..4789b0e09cc8a 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)"
DISTDIR="$SPARK_HOME/dist"
SPARK_TACHYON=false
-TACHYON_VERSION="0.6.4"
+TACHYON_VERSION="0.7.0"
TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz"
TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index fc0693f67cc2e..bc19bd6df894f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType
*/
@Experimental
final class RandomForestClassifier(override val uid: String)
- extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+ extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("rfc"))
@@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- new RandomForestClassificationModel(trees)
+ new RandomForestClassificationModel(trees, numClasses)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -125,8 +125,9 @@ object RandomForestClassifier {
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
- private val _trees: Array[DecisionTreeClassificationModel])
- extends PredictionModel[Vector, RandomForestClassificationModel]
+ private val _trees: Array[DecisionTreeClassificationModel],
+ override val numClasses: Int)
+ extends ClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeClassificationModel]) =
- this(Identifiable.randomUID("rfc"), trees)
+ def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
+ this(Identifiable.randomUID("rfc"), trees, numClasses)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] (
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override protected def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the weights since all are 1.0 for now.
- val votes = mutable.Map.empty[Int, Double]
+ val votes = new Array[Double](numClasses)
_trees.view.foreach { tree =>
val prediction = tree.rootNode.predict(features).toInt
- votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
+ votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
}
- votes.maxBy(_._2)._1
+ Vectors.dense(votes)
}
override def copy(extra: ParamMap): RandomForestClassificationModel = {
- copyValues(new RandomForestClassificationModel(uid, _trees), extra)
+ copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
}
override def toString: String = {
@@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel {
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
- categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): RandomForestClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
- new RandomForestClassificationModel(uid, newTrees)
+ new RandomForestClassificationModel(uid, newTrees, numClasses)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 3825942795645..9c60d4084ec46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
- val is = "_is_"
val inputColName = $(inputCol)
val outputColName = $(outputCol)
@@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
val outputAttrNames: Option[Array[String]] = inputAttr match {
case nominal: NominalAttribute =>
if (nominal.values.isDefined) {
- nominal.values.map(_.map(v => inputColName + is + v))
+ nominal.values
} else if (nominal.numValues.isDefined) {
- nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+ nominal.numValues.map(n => Array.tabulate(n)(_.toString))
} else {
None
}
case binary: BinaryAttribute =>
if (binary.values.isDefined) {
- binary.values.map(_.map(v => inputColName + is + v))
+ binary.values
} else {
- Some(Array.tabulate(2)(i => inputColName + is + i))
+ Some(Array.tabulate(2)(_.toString))
}
case _: NumericAttribute =>
throw new RuntimeException(
@@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
override def transform(dataset: DataFrame): DataFrame = {
// schema transformation
- val is = "_is_"
val inputColName = $(inputCol)
val outputColName = $(outputCol)
val shouldDropLast = $(dropLast)
@@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
math.max(m0, m1)
}
).toInt + 1
- val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+ val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
val outputAttrs: Array[Attribute] =
filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 0a95b1ee8de6e..d1726917e4517 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.parsing.combinator.RegexParsers
@@ -78,17 +79,33 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
/** @group getParam */
def getFormula: String = $(formula)
+ /** Whether the formula specifies fitting an intercept. */
+ private[ml] def hasIntercept: Boolean = {
+ require(parsedFormula.isDefined, "Must call setFormula() first.")
+ parsedFormula.get.hasIntercept
+ }
+
override def fit(dataset: DataFrame): RFormulaModel = {
require(parsedFormula.isDefined, "Must call setFormula() first.")
+ val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
// StringType terms and terms representing interactions need to be encoded before assembly.
// TODO(ekl) add support for feature interactions
- var encoderStages = ArrayBuffer[PipelineStage]()
- var tempColumns = ArrayBuffer[String]()
- val encodedTerms = parsedFormula.get.terms.map { term =>
+ val encoderStages = ArrayBuffer[PipelineStage]()
+ val tempColumns = ArrayBuffer[String]()
+ val takenNames = mutable.Set(dataset.columns: _*)
+ val encodedTerms = resolvedFormula.terms.map { term =>
dataset.schema(term) match {
case column if column.dataType == StringType =>
val indexCol = term + "_idx_" + uid
- val encodedCol = term + "_onehot_" + uid
+ val encodedCol = {
+ var tmp = term
+ while (takenNames.contains(tmp)) {
+ tmp += "_"
+ }
+ tmp
+ }
+ takenNames.add(indexCol)
+ takenNames.add(encodedCol)
encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
tempColumns += indexCol
@@ -103,7 +120,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
.setOutputCol($(featuresCol))
encoderStages += new ColumnPruner(tempColumns.toSet)
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
- copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this))
+ copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
}
// optimistic schema; does not contain any ML attributes
@@ -124,13 +141,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
/**
* :: Experimental ::
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
- * @param parsedFormula a pre-parsed R formula.
+ * @param resolvedFormula the fitted R formula.
* @param pipelineModel the fitted feature model, including factor to index mappings.
*/
@Experimental
class RFormulaModel private[feature](
override val uid: String,
- parsedFormula: ParsedRFormula,
+ resolvedFormula: ResolvedRFormula,
pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase {
@@ -144,8 +161,8 @@ class RFormulaModel private[feature](
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(schema)) {
withFeatures
- } else if (schema.exists(_.name == parsedFormula.label)) {
- val nullable = schema(parsedFormula.label).dataType match {
+ } else if (schema.exists(_.name == resolvedFormula.label)) {
+ val nullable = schema(resolvedFormula.label).dataType match {
case _: NumericType | BooleanType => false
case _ => true
}
@@ -158,12 +175,12 @@ class RFormulaModel private[feature](
}
override def copy(extra: ParamMap): RFormulaModel = copyValues(
- new RFormulaModel(uid, parsedFormula, pipelineModel))
+ new RFormulaModel(uid, resolvedFormula, pipelineModel))
- override def toString: String = s"RFormulaModel(${parsedFormula})"
+ override def toString: String = s"RFormulaModel(${resolvedFormula})"
private def transformLabel(dataset: DataFrame): DataFrame = {
- val labelName = parsedFormula.label
+ val labelName = resolvedFormula.label
if (hasLabelCol(dataset.schema)) {
dataset
} else if (dataset.schema.exists(_.name == labelName)) {
@@ -207,26 +224,3 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
}
-
-/**
- * Represents a parsed R formula.
- */
-private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
-
-/**
- * Limited implementation of R formula parsing. Currently supports: '~', '+'.
- */
-private[ml] object RFormulaParser extends RegexParsers {
- def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
-
- def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
-
- def formula: Parser[ParsedRFormula] =
- (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) }
-
- def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
- case Success(result, _) => result
- case failure: NoSuccess => throw new IllegalArgumentException(
- "Could not parse formula: " + value)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
new file mode 100644
index 0000000000000..1ca3b92a7d92a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.sql.types._
+
+/**
+ * Represents a parsed R formula.
+ */
+private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
+ /**
+ * Resolves formula terms into column names. A schema is necessary for inferring the meaning
+ * of the special '.' term. Duplicate terms will be removed during resolution.
+ */
+ def resolve(schema: StructType): ResolvedRFormula = {
+ var includedTerms = Seq[String]()
+ terms.foreach {
+ case Dot =>
+ includedTerms ++= simpleTypes(schema).filter(_ != label.value)
+ case ColumnRef(value) =>
+ includedTerms :+= value
+ case Deletion(term: Term) =>
+ term match {
+ case ColumnRef(value) =>
+ includedTerms = includedTerms.filter(_ != value)
+ case Dot =>
+ // e.g. "- .", which removes all first-order terms
+ val fromSchema = simpleTypes(schema)
+ includedTerms = includedTerms.filter(fromSchema.contains(_))
+ case _: Deletion =>
+ assert(false, "Deletion terms cannot be nested")
+ case _: Intercept =>
+ }
+ case _: Intercept =>
+ }
+ ResolvedRFormula(label.value, includedTerms.distinct)
+ }
+
+ /** Whether this formula specifies fitting with an intercept term. */
+ def hasIntercept: Boolean = {
+ var intercept = true
+ terms.foreach {
+ case Intercept(enabled) =>
+ intercept = enabled
+ case Deletion(Intercept(enabled)) =>
+ intercept = !enabled
+ case _ =>
+ }
+ intercept
+ }
+
+ // the dot operator excludes complex column types
+ private def simpleTypes(schema: StructType): Seq[String] = {
+ schema.fields.filter(_.dataType match {
+ case _: NumericType | StringType | BooleanType | _: VectorUDT => true
+ case _ => false
+ }).map(_.name)
+ }
+}
+
+/**
+ * Represents a fully evaluated and simplified R formula.
+ */
+private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
+
+/**
+ * R formula terms. See the R formula docs here for more information:
+ * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ */
+private[ml] sealed trait Term
+
+/* R formula reference to all available columns, e.g. "." in a formula */
+private[ml] case object Dot extends Term
+
+/* R formula reference to a column, e.g. "+ Species" in a formula */
+private[ml] case class ColumnRef(value: String) extends Term
+
+/* R formula intercept toggle, e.g. "+ 0" in a formula */
+private[ml] case class Intercept(enabled: Boolean) extends Term
+
+/* R formula deletion of a variable, e.g. "- Species" in a formula */
+private[ml] case class Deletion(term: Term) extends Term
+
+/**
+ * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'.
+ */
+private[ml] object RFormulaParser extends RegexParsers {
+ def intercept: Parser[Intercept] =
+ "([01])".r ^^ { case a => Intercept(a == "1") }
+
+ def columnRef: Parser[ColumnRef] =
+ "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
+
+ def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
+
+ def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
+ case op ~ list => list.foldLeft(List(op)) {
+ case (left, "+" ~ right) => left ++ Seq(right)
+ case (left, "-" ~ right) => left ++ Seq(Deletion(right))
+ }
+ }
+
+ def formula: Parser[ParsedRFormula] =
+ (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+
+ def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => throw new IllegalArgumentException(
+ "Could not parse formula: " + value)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 0b3af4747e693..248288ca73e99 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -50,7 +50,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
/**
* :: Experimental ::
* A regex based tokenizer that extracts tokens either by using the provided regex pattern to split
- * the text (default) or repeatedly matching the regex (if `gaps` is true).
+ * the text (default) or repeatedly matching the regex (if `gaps` is false).
* Optional parameters also allow filtering tokens using a minimal length.
* It returns an array of strings that can be empty.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 1ee080641e3e3..f5a022c31ed90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -17,9 +17,10 @@
package org.apache.spark.ml.api.r
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.DataFrame
@@ -32,10 +33,38 @@ private[r] object SparkRWrappers {
alpha: Double): PipelineModel = {
val formula = new RFormula().setFormula(value)
val estimator = family match {
- case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
- case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
+ case "gaussian" => new LinearRegression()
+ .setRegParam(lambda)
+ .setElasticNetParam(alpha)
+ .setFitIntercept(formula.hasIntercept)
+ case "binomial" => new LogisticRegression()
+ .setRegParam(lambda)
+ .setElasticNetParam(alpha)
+ .setFitIntercept(formula.hasIntercept)
}
val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
}
+
+ def getModelWeights(model: PipelineModel): Array[Double] = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ Array(m.intercept) ++ m.weights.toArray
+ case _: LogisticRegressionModel =>
+ throw new UnsupportedOperationException(
+ "No weights available for LogisticRegressionModel") // SPARK-9492
+ }
+ }
+
+ def getModelFeatures(model: PipelineModel): Array[String] = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ val attrs = AttributeGroup.fromStructField(
+ m.summary.predictions.schema(m.summary.featuresCol))
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ case _: LogisticRegressionModel =>
+ throw new UnsupportedOperationException(
+ "No features names available for LogisticRegressionModel") // SPARK-9492
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
new file mode 100644
index 0000000000000..4ece8cf8cf0b6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam}
+import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
+import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{DoubleType, DataType}
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for isotonic regression.
+ */
+private[regression] trait IsotonicRegressionParams extends PredictorParams {
+
+ /**
+ * Param for weight column name.
+ * TODO: Move weightCol to sharedParams.
+ *
+ * @group param
+ */
+ final val weightCol: Param[String] =
+ new Param[String](this, "weightCol", "weight column name")
+
+ /** @group getParam */
+ final def getWeightCol: String = $(weightCol)
+
+ /**
+ * Param for isotonic parameter.
+ * Isotonic (increasing) or antitonic (decreasing) sequence.
+ * @group param
+ */
+ final val isotonic: BooleanParam =
+ new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence")
+
+ /** @group getParam */
+ final def getIsotonicParam: Boolean = $(isotonic)
+}
+
+/**
+ * :: Experimental ::
+ * Isotonic regression.
+ *
+ * Currently implemented using parallelized pool adjacent violators algorithm.
+ * Only univariate (single feature) algorithm supported.
+ *
+ * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]].
+ */
+@Experimental
+class IsotonicRegression(override val uid: String)
+ extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel]
+ with IsotonicRegressionParams {
+
+ def this() = this(Identifiable.randomUID("isoReg"))
+
+ /**
+ * Set the isotonic parameter.
+ * Default is true.
+ * @group setParam
+ */
+ def setIsotonicParam(value: Boolean): this.type = set(isotonic, value)
+ setDefault(isotonic -> true)
+
+ /**
+ * Set weight column param.
+ * Default is weight.
+ * @group setParam
+ */
+ def setWeightParam(value: String): this.type = set(weightCol, value)
+ setDefault(weightCol -> "weight")
+
+ override private[ml] def featuresDataType: DataType = DoubleType
+
+ override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
+
+ private[this] def extractWeightedLabeledPoints(
+ dataset: DataFrame): RDD[(Double, Double, Double)] = {
+
+ dataset.select($(labelCol), $(featuresCol), $(weightCol))
+ .map { case Row(label: Double, features: Double, weights: Double) =>
+ (label, features, weights)
+ }
+ }
+
+ override protected def train(dataset: DataFrame): IsotonicRegressionModel = {
+ SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType)
+ // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ val instances = extractWeightedLabeledPoints(dataset)
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
+ val parentModel = isotonicRegression.run(instances)
+
+ new IsotonicRegressionModel(uid, parentModel)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by IsotonicRegression.
+ * Predicts using a piecewise linear function.
+ *
+ * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]].
+ *
+ * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]]
+ * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]].
+ */
+class IsotonicRegressionModel private[ml] (
+ override val uid: String,
+ private[ml] val parentModel: MLlibIsotonicRegressionModel)
+ extends RegressionModel[Double, IsotonicRegressionModel]
+ with IsotonicRegressionParams {
+
+ override def featuresDataType: DataType = DoubleType
+
+ override protected def predict(features: Double): Double = {
+ parentModel.predict(features)
+ }
+
+ override def copy(extra: ParamMap): IsotonicRegressionModel = {
+ copyValues(new IsotonicRegressionModel(uid, parentModel), extra)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 89718e0f3e15a..3b85ba001b128 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructField
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
@@ -146,9 +147,10 @@ class LinearRegression(override val uid: String)
val model = new LinearRegressionModel(uid, weights, intercept)
val trainingSummary = new LinearRegressionTrainingSummary(
- model.transform(dataset).select($(predictionCol), $(labelCol)),
+ model.transform(dataset),
$(predictionCol),
$(labelCol),
+ $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
}
@@ -221,9 +223,10 @@ class LinearRegression(override val uid: String)
val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
val trainingSummary = new LinearRegressionTrainingSummary(
- model.transform(dataset).select($(predictionCol), $(labelCol)),
+ model.transform(dataset),
$(predictionCol),
$(labelCol),
+ $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ val featuresCol: String,
val objectiveHistory: Array[Double])
extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
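Since the training summary now receives the fully transformed DataFrame plus the features column name, callers can locate the feature values alongside predictions and labels. A hedged sketch of reading the summary after this change; the `training` DataFrame is an illustrative assumption, and only the `featuresCol` plumbing is in the diff:

```scala
// Illustrative only; assumes `training` is a DataFrame of (label, features) rows.
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression().setMaxIter(10)
val model = lr.fit(training)

val summary = model.summary                       // LinearRegressionTrainingSummary
println(summary.featuresCol)                      // newly exposed alongside predictionCol/labelCol
println(summary.objectiveHistory.mkString(", "))  // loss history recorded during training
```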
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
new file mode 100644
index 0000000000000..0ec88ef77d695
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
+
+/**
+ * Wrapper around GaussianMixtureModel to provide helper methods for the Python API
+ */
+private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
+ val weights: Vector = Vectors.dense(model.weights)
+ val k: Int = weights.size
+
+ /**
+ * Returns gaussians as a List of Vectors and Matrices corresponding to each MultivariateGaussian
+ */
+ val gaussians: JList[Object] = {
+ val modelGaussians = model.gaussians
+ var i = 0
+ var mu = ArrayBuffer.empty[Vector]
+ var sigma = ArrayBuffer.empty[Matrix]
+ while (i < k) {
+ mu += modelGaussians(i).mu
+ sigma += modelGaussians(i).sigma
+ i += 1
+ }
+ List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ }
+
+ def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
+}
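To make the structure of the `gaussians` member explicit, here is a functionally equivalent sketch written with `map` instead of the index loop: element 0 of the returned list is an `Array[Vector]` of means and element 1 an `Array[Matrix]` of covariances. The helper name is hypothetical and not part of the patch.

```scala
import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.{Matrix, Vector}

// Hypothetical helper equivalent to the `gaussians` member above.
def gaussiansAsJava(model: GaussianMixtureModel): JList[Object] = {
  val mu: Array[Vector] = model.gaussians.map(_.mu)        // per-component means
  val sigma: Array[Matrix] = model.gaussians.map(_.sigma)  // per-component covariances
  List(mu, sigma).map(_.asInstanceOf[Object]).asJava
}
```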
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index fda8d5a0b048f..6f080d32bbf4d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
seed: java.lang.Long,
initialModelWeights: java.util.ArrayList[Double],
initialModelMu: java.util.ArrayList[Vector],
- initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
+ initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
val gmmAlg = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
@@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
if (seed != null) gmmAlg.setSeed(seed)
try {
- val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
- var wt = ArrayBuffer.empty[Double]
- var mu = ArrayBuffer.empty[Vector]
- var sigma = ArrayBuffer.empty[Matrix]
- for (i <- 0 until model.k) {
- wt += model.weights(i)
- mu += model.gaussians(i).mu
- sigma += model.gaussians(i).sigma
- }
- List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
} finally {
data.rdd.unpersist(blocking = false)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 31c1d520fd659..6cfad3fbbdb87 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,10 +17,9 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}
-
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.numerics.{exp, lgamma}
import org.apache.hadoop.fs.Path
-
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -28,14 +27,13 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
-import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
-import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
-import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
+import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.BoundedPriorityQueue
-
/**
* :: Experimental ::
*
@@ -53,6 +51,31 @@ abstract class LDAModel private[clustering] extends Saveable {
   /** Vocabulary size (number of terms or words in the vocabulary) */
def vocabSize: Int
+ /**
+ * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+ * distributions over topics ("theta").
+ *
+ * This is the parameter to a Dirichlet distribution.
+ */
+ def docConcentration: Vector
+
+ /**
+ * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
+ * distributions over terms.
+ *
+ * This is the parameter to a symmetric Dirichlet distribution.
+ *
+ * Note: The topics' distributions over terms are called "beta" in the original LDA paper
+ * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
+ */
+ def topicConcentration: Double
+
+ /**
+ * Shape parameter for random initialization of variational parameter gamma.
+ * Used for variational inference for perplexity and other test-time computations.
+ */
+ protected def gammaShape: Double
+
/**
* Inferred topics, where each topic is represented by a distribution over terms.
* This is a matrix of size vocabSize x k, where each column is a topic.
@@ -163,12 +186,14 @@ abstract class LDAModel private[clustering] extends Saveable {
* This model stores only the inferred topics.
* It may be used for computing topics for new documents, but it may give less accurate answers
* than the [[DistributedLDAModel]].
- *
* @param topics Inferred topics (vocabSize x k matrix).
*/
@Experimental
class LocalLDAModel private[clustering] (
- private val topics: Matrix) extends LDAModel with Serializable {
+ val topics: Matrix,
+ override val docConcentration: Vector,
+ override val topicConcentration: Double,
+ override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable {
override def k: Int = topics.numCols
@@ -189,16 +214,122 @@ class LocalLDAModel private[clustering] (
override protected def formatVersion = "1.0"
override def save(sc: SparkContext, path: String): Unit = {
- LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
+ LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
+ gammaShape)
}
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
- // TODO:
- // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+ /**
+ * Calculate the log variational bound on perplexity. See Equation (16) in the original Online
+ * LDA paper.
+ * @param documents test corpus to use for calculating perplexity
+ * @return the log perplexity per word
+ */
+ def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
+ val corpusWords = documents
+ .map { case (_, termCounts) => termCounts.toArray.sum }
+ .sum()
+ val batchVariationalBound = bound(documents, docConcentration,
+ topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize)
+ val perWordBound = batchVariationalBound / corpusWords
+
+ perWordBound
+ }
+
+ /**
+ * Estimate the variational likelihood bound from `documents`:
+ * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)]
+ * This bound is derived by decomposing the LDA model to:
+ * log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p)
+ * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in the original Online LDA paper.
+ * @param documents a subset of the test corpus
+ * @param alpha document-topic Dirichlet prior parameters
+ * @param eta topic-word Dirichlet prior parameters
+ * @param lambda parameters for variational q(beta | lambda) topic-word distributions
+ * @param gammaShape shape parameter for random initialization of variational q(theta | gamma)
+ * topic mixture distributions
+ * @param k number of topics
+ * @param vocabSize number of unique terms in the entire test corpus
+ */
+ private def bound(
+ documents: RDD[(Long, Vector)],
+ alpha: Vector,
+ eta: Double,
+ lambda: BDM[Double],
+ gammaShape: Double,
+ k: Int,
+ vocabSize: Long): Double = {
+ val brzAlpha = alpha.toBreeze.toDenseVector
+ // transpose because dirichletExpectation normalizes by row and we need to normalize
+ // by topic (columns of lambda)
+ val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
+
+ var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
+ var docScore = 0.0D
+ val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts, exp(Elogbeta), brzAlpha, gammaShape, k)
+ val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
+
+ // E[log p(doc | theta, beta)]
+ termCounts.foreachActive { case (idx, count) =>
+ docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t)
+ }
+ // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector
+ docScore += sum((brzAlpha - gammad) :* Elogthetad)
+ docScore += sum(lgamma(gammad) - lgamma(brzAlpha))
+ docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
+
+ docScore
+ }.sum()
+
+ // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar
+ score += sum((eta - lambda) :* Elogbeta)
+ score += sum(lgamma(lambda) - lgamma(eta))
+
+ val sumEta = eta * vocabSize
+ score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
+
+ score
+ }
+
+ /**
+ * Predicts the topic mixture distribution for each document (often called "theta" in the
+ * literature). Returns a vector of zeros for an empty document.
+ *
+ * This uses a variational approximation following Hoffman et al. (2010), where the approximate
+ * distribution is called "gamma." Technically, this method returns this approximation "gamma"
+ * for each document.
+ * @param documents documents to predict topic mixture distributions for
+ * @return An RDD of (document ID, topic mixture distribution for document)
+ */
+ // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
+ def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
+ // Double transpose because dirichletExpectation normalizes by row and we need to normalize
+ // by topic (columns of lambda)
+ val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+ val docConcentrationBrz = this.docConcentration.toBreeze
+ val gammaShape = this.gammaShape
+ val k = this.k
+
+ documents.map { case (id: Long, termCounts: Vector) =>
+ if (termCounts.numNonzeros == 0) {
+ (id, Vectors.zeros(k))
+ } else {
+ val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts,
+ expElogbeta,
+ docConcentrationBrz,
+ gammaShape,
+ k)
+ (id, Vectors.dense(normalize(gamma, 1.0).toArray))
+ }
+ }
+ }
}
+
@Experimental
object LocalLDAModel extends Loader[LocalLDAModel] {
@@ -212,14 +343,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
// as a Row in data.
case class Data(topic: Vector, index: Int)
- def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
+ def save(
+ sc: SparkContext,
+ path: String,
+ topicsMatrix: Matrix,
+ docConcentration: Vector,
+ topicConcentration: Double,
+ gammaShape: Double): Unit = {
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val k = topicsMatrix.numCols
val metadata = compact(render
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
- ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+ ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
+ ("docConcentration" -> docConcentration.toArray.toSeq) ~
+ ("topicConcentration" -> topicConcentration) ~
+ ("gammaShape" -> gammaShape)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
@@ -229,7 +369,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
}
- def load(sc: SparkContext, path: String): LocalLDAModel = {
+ def load(
+ sc: SparkContext,
+ path: String,
+ docConcentration: Vector,
+ topicConcentration: Double,
+ gammaShape: Double): LocalLDAModel = {
val dataPath = Loader.dataPath(path)
val sqlContext = SQLContext.getOrCreate(sc)
val dataFrame = sqlContext.read.parquet(dataPath)
@@ -243,7 +388,10 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
topics.foreach { case Row(vec: Vector, ind: Int) =>
brzTopics(::, ind) := vec.toBreeze
}
- new LocalLDAModel(Matrices.fromBreeze(brzTopics))
+ val topicsMat = Matrices.fromBreeze(brzTopics)
+
+ // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
+ new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
}
}
@@ -252,15 +400,19 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
implicit val formats = DefaultFormats
val expectedK = (metadata \ "k").extract[Int]
val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
+ val docConcentration =
+ Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
+ val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+ val gammaShape = (metadata \ "gammaShape").extract[Double]
val classNameV1_0 = SaveLoadV1_0.thisClassName
val model = (loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 =>
- SaveLoadV1_0.load(sc, path)
+ SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
case _ => throw new Exception(
s"LocalLDAModel.load did not recognize model with (className, format version):" +
- s"($loadedClassName, $loadedVersion). Supported:\n" +
- s" ($classNameV1_0, 1.0)")
+ s"($loadedClassName, $loadedVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
}
val topicsMatrix = model.topicsMatrix
@@ -268,7 +420,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
require(expectedVocabSize == topicsMatrix.numRows,
s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
- s"but got ${topicsMatrix.numRows}")
+ s"but got ${topicsMatrix.numRows}")
model
}
}
@@ -282,28 +434,25 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
* than the [[LocalLDAModel]].
*/
@Experimental
-class DistributedLDAModel private (
+class DistributedLDAModel private[clustering] (
private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
private[clustering] val globalTopicTotals: LDA.TopicCounts,
val k: Int,
val vocabSize: Int,
- private[clustering] val docConcentration: Double,
- private[clustering] val topicConcentration: Double,
+ override val docConcentration: Vector,
+ override val topicConcentration: Double,
+ override protected[clustering] val gammaShape: Double,
private[spark] val iterationTimes: Array[Double]) extends LDAModel {
import LDA._
- private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
- this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
- state.topicConcentration, iterationTimes)
- }
-
/**
* Convert model to a local model.
* The local model stores the inferred topics but not the topic distributions for training
* documents.
*/
- def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix)
+ def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
+ gammaShape)
/**
* Inferred topics, where each topic is represented by a distribution over terms.
@@ -375,8 +524,9 @@ class DistributedLDAModel private (
* hyperparameters.
*/
lazy val logLikelihood: Double = {
- val eta = topicConcentration
- val alpha = docConcentration
+ // TODO: generalize this for asymmetric (non-scalar) alpha
+ val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
+ val eta = this.topicConcentration
assert(eta > 1.0)
assert(alpha > 1.0)
val N_k = globalTopicTotals
@@ -400,8 +550,9 @@ class DistributedLDAModel private (
* log P(topics, topic distributions for docs | alpha, eta)
*/
lazy val logPrior: Double = {
- val eta = topicConcentration
- val alpha = docConcentration
+ // TODO: generalize this for asymmetric (non-scalar) alpha
+ val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
+ val eta = this.topicConcentration
// Term vertices: Compute phi_{wk}. Use to compute prior log probability.
// Doc vertex: Compute theta_{kj}. Use to compute prior log probability.
val N_k = globalTopicTotals
@@ -412,12 +563,12 @@ class DistributedLDAModel private (
val N_wk = vertex._2
val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
- (eta - 1.0) * brzSum(phi_wk.map(math.log))
+ (eta - 1.0) * sum(phi_wk.map(math.log))
} else {
val N_kj = vertex._2
val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0)
val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
- (alpha - 1.0) * brzSum(theta_kj.map(math.log))
+ (alpha - 1.0) * sum(theta_kj.map(math.log))
}
}
graph.vertices.aggregate(0.0)(seqOp, _ + _)
@@ -448,7 +599,7 @@ class DistributedLDAModel private (
override def save(sc: SparkContext, path: String): Unit = {
DistributedLDAModel.SaveLoadV1_0.save(
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
- iterationTimes)
+ iterationTimes, gammaShape)
}
}
@@ -460,7 +611,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val thisFormatVersion = "1.0"
- val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
+ val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"
// Store globalTopicTotals as a Vector.
case class Data(globalTopicTotals: Vector)
@@ -478,17 +629,20 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
globalTopicTotals: LDA.TopicCounts,
k: Int,
vocabSize: Int,
- docConcentration: Double,
+ docConcentration: Vector,
topicConcentration: Double,
- iterationTimes: Array[Double]): Unit = {
+ iterationTimes: Array[Double],
+ gammaShape: Double): Unit = {
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val metadata = compact(render
- (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
- ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
- ("topicConcentration" -> topicConcentration) ~
- ("iterationTimes" -> iterationTimes.toSeq)))
+ (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
+ ("docConcentration" -> docConcentration.toArray.toSeq) ~
+ ("topicConcentration" -> topicConcentration) ~
+ ("iterationTimes" -> iterationTimes.toSeq) ~
+ ("gammaShape" -> gammaShape)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
@@ -510,9 +664,10 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
sc: SparkContext,
path: String,
vocabSize: Int,
- docConcentration: Double,
+ docConcentration: Vector,
topicConcentration: Double,
- iterationTimes: Array[Double]): DistributedLDAModel = {
+ iterationTimes: Array[Double],
+ gammaShape: Double): DistributedLDAModel = {
val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
@@ -536,7 +691,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
- docConcentration, topicConcentration, iterationTimes)
+ docConcentration, topicConcentration, gammaShape, iterationTimes)
}
}
@@ -546,32 +701,35 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
implicit val formats = DefaultFormats
val expectedK = (metadata \ "k").extract[Int]
val vocabSize = (metadata \ "vocabSize").extract[Int]
- val docConcentration = (metadata \ "docConcentration").extract[Double]
+ val docConcentration =
+ Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
val topicConcentration = (metadata \ "topicConcentration").extract[Double]
val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
- val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+ val gammaShape = (metadata \ "gammaShape").extract[Double]
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
val model = (loadedClassName, loadedVersion) match {
case (className, "1.0") if className == classNameV1_0 => {
- DistributedLDAModel.SaveLoadV1_0.load(
- sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
+ DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration,
+ topicConcentration, iterationTimes.toArray, gammaShape)
}
case _ => throw new Exception(
s"DistributedLDAModel.load did not recognize model with (className, format version):" +
- s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)")
+ s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)")
}
require(model.vocabSize == vocabSize,
s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
require(model.docConcentration == docConcentration,
s"DistributedLDAModel requires $docConcentration docConcentration, " +
- s"got ${model.docConcentration} docConcentration")
+ s"got ${model.docConcentration} docConcentration")
require(model.topicConcentration == topicConcentration,
s"DistributedLDAModel requires $topicConcentration docConcentration, " +
- s"got ${model.topicConcentration} docConcentration")
+ s"got ${model.topicConcentration} docConcentration")
require(expectedK == model.k,
s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
model
}
}
+
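The reworked LocalLDAModel now supports test-time inference directly. A hedged end-to-end sketch of the new `logPerplexity` and `topicDistributions` calls; the toy corpus and the existing SparkContext `sc` are assumptions, and training with the online optimizer is just one way to obtain a LocalLDAModel:

```scala
import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Documents are (id, term-count vector) pairs, as elsewhere in the LDA API.
val docs: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 1.0, 3.0))
))

// The online optimizer produces a LocalLDAModel.
val model = new LDA().setK(2).setOptimizer("online").run(docs).asInstanceOf[LocalLDAModel]

// Per-word log variational bound over a (here: the training) corpus.
println(model.logPerplexity(docs))

// Approximate topic mixture ("gamma") per document; empty documents map to zero vectors.
model.topicDistributions(docs).collect().foreach(println)
```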
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index f4170a3d98dd8..d6f8b29a43dfd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import java.util.Random
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
-import breeze.numerics.{abs, digamma, exp}
+import breeze.numerics.{abs, exp}
import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.DeveloperApi
@@ -142,8 +142,9 @@ final class EMLDAOptimizer extends LDAOptimizer {
this.k = k
this.vocabSize = docs.take(1).head._2.size
this.checkpointInterval = lda.getCheckpointInterval
- this.graphCheckpointer = new
- PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
+ this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
+ checkpointInterval, graph.vertices.sparkContext)
+ this.graphCheckpointer.update(this.graph)
this.globalTopicTotals = computeGlobalTopicTotals()
this
}
@@ -188,7 +189,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
// Update the vertex descriptors with the new counts.
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
graph = newGraph
- graphCheckpointer.updateGraph(newGraph)
+ graphCheckpointer.update(newGraph)
globalTopicTotals = computeGlobalTopicTotals()
this
}
@@ -208,7 +209,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
this.graphCheckpointer.deleteAllCheckpoints()
- new DistributedLDAModel(this, iterationTimes)
+ // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal
+ // conversion
+ new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
+ Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
+ 100, iterationTimes)
}
}
@@ -385,71 +390,52 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
iteration += 1
val k = this.k
val vocabSize = this.vocabSize
- val Elogbeta = dirichletExpectation(lambda).t
- val expElogbeta = exp(Elogbeta)
+ val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t
val alpha = this.alpha.toBreeze
val gammaShape = this.gammaShape
- val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
- val stat = BDM.zeros[Double](k, vocabSize)
- docs.foreach { doc =>
- val termCounts = doc._2
- val (ids: List[Int], cts: Array[Double]) = termCounts match {
- case v: DenseVector => ((0 until v.size).toList, v.values)
- case v: SparseVector => (v.indices.toList, v.values)
- case v => throw new IllegalArgumentException("Online LDA does not support vector type "
- + v.getClass)
- }
- if (!ids.isEmpty) {
-
- // Initialize the variational distribution q(theta|gamma) for the mini-batch
- val gammad: BDV[Double] =
- new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
- val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
- val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K
-
- val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
- var meanchange = 1D
- val ctsVector = new BDV[Double](cts) // ids
-
- // Iterate between gamma and phi until convergence
- while (meanchange > 1e-3) {
- val lastgamma = gammad.copy
- // K K * ids ids
- gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
- expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
- phinorm := expElogbetad * expElogthetad :+ 1e-100
- meanchange = sum(abs(gammad - lastgamma)) / k
- }
+ val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
+ val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
- stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
+ val stat = BDM.zeros[Double](k, vocabSize)
+ var gammaPart = List[BDV[Double]]()
+ nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) =>
+ val ids: List[Int] = termCounts match {
+ case v: DenseVector => (0 until v.size).toList
+ case v: SparseVector => v.indices.toList
}
+ val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts, expElogbeta, alpha, gammaShape, k)
+ stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
+ gammaPart = gammad :: gammaPart
}
- Iterator(stat)
+ Iterator((stat, gammaPart))
}
-
- val statsSum: BDM[Double] = stats.reduce(_ += _)
+ val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+ val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
+ stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
val batchResult = statsSum :* expElogbeta.t
// Note that this is an optimization to avoid batch.count
- update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
+ updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
this
}
- override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
- new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
- }
-
/**
* Update lambda based on the batch submitted. batchSize can be different for each iteration.
*/
- private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
+ private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = {
// weight of the mini-batch.
- val weight = math.pow(getTau0 + iter, -getKappa)
+ val weight = rho()
// Update lambda based on documents.
- lambda = lambda * (1 - weight) +
- (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
+ lambda := (1 - weight) * lambda +
+ weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
+ }
+
+ /** Calculates learning rate rho, which decays as a function of [[iteration]] */
+ private def rho(): Double = {
+ math.pow(getTau0 + this.iteration, -getKappa)
}
/**
@@ -463,15 +449,57 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
new BDM[Double](col, row, temp).t
}
+ override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
+ new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
+ }
+
+}
+
+/**
+ * Serializable companion object containing helper methods and shared code for
+ * [[OnlineLDAOptimizer]] and [[LocalLDAModel]].
+ */
+private[clustering] object OnlineLDAOptimizer {
/**
- * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
- * uses digamma which is accurate but expensive.
+ * Uses variational inference to infer the topic distribution `gammad` given the term counts
+ * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will
+ * throw a BLAS error.
+ *
+ * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001)
+ * avoids explicit computation of variational parameter `phi`.
+ * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]]
*/
- private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
- val rowSum = sum(alpha(breeze.linalg.*, ::))
- val digAlpha = digamma(alpha)
- val digRowSum = digamma(rowSum)
- val result = digAlpha(::, breeze.linalg.*) - digRowSum
- result
+ private[clustering] def variationalTopicInference(
+ termCounts: Vector,
+ expElogbeta: BDM[Double],
+ alpha: breeze.linalg.Vector[Double],
+ gammaShape: Double,
+ k: Int): (BDV[Double], BDM[Double]) = {
+ val (ids: List[Int], cts: Array[Double]) = termCounts match {
+ case v: DenseVector => ((0 until v.size).toList, v.values)
+ case v: SparseVector => (v.indices.toList, v.values)
+ }
+ // Initialize the variational distribution q(theta|gamma) for the mini-batch
+ val gammad: BDV[Double] =
+ new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
+ val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K
+ val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K
+
+ val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
+ var meanchange = 1D
+ val ctsVector = new BDV[Double](cts) // ids
+
+ // Iterate between gamma and phi until convergence
+ while (meanchange > 1e-3) {
+ val lastgamma = gammad.copy
+ // K K * ids ids
+ gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+ expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
+ phinorm := expElogbetad * expElogthetad :+ 1e-100
+ meanchange = sum(abs(gammad - lastgamma)) / k
+ }
+
+ val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
+ (gammad, sstatsd)
}
}
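Because `variationalTopicInference` is now shared between the optimizer and `LocalLDAModel`, its contract is worth spelling out. The sketch below uses a hypothetical object name and toy values, and it would have to live inside the `org.apache.spark.mllib.clustering` package since the helper is `private[clustering]`. It illustrates the expected shapes: `expElogbeta` is vocabSize x k, `alpha` has length k, and the result is a length-k `gamma` vector plus a k x |ids| matrix of sufficient statistics.

```scala
package org.apache.spark.mllib.clustering

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}

import org.apache.spark.mllib.linalg.Vectors

// Hypothetical sanity-check driver, not part of the patch.
object VariationalInferenceSketch {
  def main(args: Array[String]): Unit = {
    val k = 2
    val vocabSize = 3
    val expElogbeta = BDM.ones[Double](vocabSize, k)  // assumed uniform exp(E[log beta])
    val alpha = BDV.fill(k)(1.0 / k)                  // symmetric document-topic prior
    val termCounts = Vectors.sparse(vocabSize, Seq((0, 2.0), (2, 1.0)))

    val (gamma, sstats) = OnlineLDAOptimizer.variationalTopicInference(
      termCounts, expElogbeta, alpha, 100.0, k)

    println(gamma)                              // variational document-topic parameters (length k)
    println(s"${sstats.rows} x ${sstats.cols}") // k x (number of non-zero term ids)
  }
}
```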
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
new file mode 100644
index 0000000000000..f7e5ce1665fe6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.clustering
+
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum}
+import breeze.numerics._
+
+/**
+ * Utility methods for LDA.
+ */
+object LDAUtils {
+ /**
+ * Log Sum Exp with overflow protection using the identity:
+ * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\}
+ */
+ private[clustering] def logSumExp(x: BDV[Double]): Double = {
+ val a = max(x)
+ a + log(sum(exp(x :- a)))
+ }
+
+ /**
+ * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
+ * uses [[breeze.numerics.digamma]] which is accurate but expensive.
+ */
+ private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = {
+ digamma(alpha) - digamma(sum(alpha))
+ }
+
+ /**
+ * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha is a vector of
+ * Dirichlet parameters.
+ */
+ private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
+ val rowSum = sum(alpha(breeze.linalg.*, ::))
+ val digAlpha = digamma(alpha)
+ val digRowSum = digamma(rowSum)
+ val result = digAlpha(::, breeze.linalg.*) - digRowSum
+ result
+ }
+
+}
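A quick sanity-check sketch for the two helpers above; it would need to sit in the `org.apache.spark.mllib.clustering` package because both methods are `private[clustering]`, and the input values are made up:

```scala
package org.apache.spark.mllib.clustering

import breeze.linalg.{sum, DenseVector => BDV}
import breeze.numerics.{digamma, exp, log}

// Hypothetical sanity-check driver, not part of the patch.
object LDAUtilsSketch {
  def main(args: Array[String]): Unit = {
    val x = BDV(1.0, 2.0, 3.0)

    // logSumExp(x) should agree with the naive log(sum(exp(x))) up to rounding.
    println(LDAUtils.logSumExp(x))
    println(log(sum(exp(x))))

    // dirichletExpectation(alpha)(i) = digamma(alpha(i)) - digamma(sum(alpha))
    val alpha = BDV(0.5, 0.5, 0.5)
    println(LDAUtils.dirichletExpectation(alpha))
    println(digamma(alpha) - digamma(sum(alpha)))
  }
}
```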
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 7ead6327486cc..0ea792081086d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
minCount: Long,
maxPatternLength: Int,
prefixes: List[Int],
- database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+ database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
}
}
- def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+ def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
database
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
*/
private def getFreqItemAndCounts(
minCount: Long,
- database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+ database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
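The signature change means the local phase can now operate on any `Iterable` of suffixes (for example, the values produced by `groupByKey`) rather than requiring a materialized `Array[Array[Int]]`. A tiny illustrative sketch with a hypothetical object name; it must live in `org.apache.spark.mllib.fpm` because `LocalPrefixSpan` is `private[fpm]`:

```scala
package org.apache.spark.mllib.fpm

// Hypothetical driver for the generalized project() signature, not part of the patch.
object LocalPrefixSpanSketch {
  def main(args: Array[String]): Unit = {
    // Any Iterable of sequences works now, not only Array[Array[Int]].
    val database: Iterable[Array[Int]] = Seq(
      Array(1, 2, 3),
      Array(1, 3, 2),
      Array(2, 3))

    // Keep the non-empty suffixes that follow the prefix item 1.
    val projected = LocalPrefixSpan.project(database, 1)
    projected.foreach(seq => println(seq.mkString(" ")))
  }
}
```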
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 6f52db7b073ae..e6752332cdeeb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.fpm
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
private var minSupport: Double,
private var maxPatternLength: Int) extends Logging with Serializable {
+ /**
+ * The maximum number of items allowed in a projected database before local processing. If a
+ * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+ */
+ // TODO: make configurable with a better default value, 10000 may be too small
+ private val maxLocalProjDBSize: Long = 10000
+
/**
* Constructs a default instance with default parameters
* {minSupport: `0.1`, maxPatternLength: `10`}.
*/
def this() = this(0.1, 10)
+ /**
+ * Get the minimal support (i.e. the minimal frequency of occurrence before a pattern is
+ * considered frequent).
+ */
+ def getMinSupport: Double = this.minSupport
+
/**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
- require(minSupport >= 0 && minSupport <= 1,
- "The minimum support value must be between 0 and 1, including 0 and 1.")
+ require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
this.minSupport = minSupport
this
}
+ /**
+ * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
+ */
+ def getMaxPatternLength: Int = this.maxPatternLength
+
/**
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
- require(maxPatternLength >= 1,
- "The maximum pattern length value must be greater than 0.")
+ // TODO: support unbounded pattern length when maxPatternLength = 0
+ require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
this.maxPatternLength = maxPatternLength
this
}
@@ -78,81 +97,153 @@ class PrefixSpan private (
* the value of pair is the pattern's count.
*/
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+ val sc = sequences.sparkContext
+
if (sequences.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
- val minCount = getMinCount(sequences)
- val lengthOnePatternsAndCounts =
- getFreqItemAndCounts(minCount, sequences).collect()
- val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
- lengthOnePatternsAndCounts.map(_._1), sequences)
- val groupedProjectedDatabase = prefixAndProjectedDatabase
- .map(x => (x._1.toSeq, x._2))
- .groupByKey()
- .map(x => (x._1.toArray, x._2.toArray))
- val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
- val lengthOnePatternsAndCountsRdd =
- sequences.sparkContext.parallelize(
- lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
- val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
- allPatterns
+
+ // Convert min support to a min number of transactions for this dataset
+ val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+ // (Frequent item -> number of occurrences) pairs; all items here satisfy the `minSupport` threshold
+ val freqItemCounts = sequences
+ .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+ .reduceByKey(_ + _)
+ .filter(_._2 >= minCount)
+ .collect()
+
+ // Pairs of (length 1 prefix, suffix consisting of frequent items)
+ val itemSuffixPairs = {
+ val freqItems = freqItemCounts.map(_._1).toSet
+ sequences.flatMap { seq =>
+ val filteredSeq = seq.filter(freqItems.contains(_))
+ freqItems.flatMap { item =>
+ val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+ candidateSuffix match {
+ case suffix if !suffix.isEmpty => Some((List(item), suffix))
+ case _ => None
+ }
+ }
+ }
+ }
+
+ // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
+ // frequent length-one prefixes)
+ var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+ // Remaining work to be processed locally and in a distributed fashion, respectively
+ var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+ // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
+ // projected database sizes <= `maxLocalProjDBSize`)
+ while (pairsForDistributed.count() != 0) {
+ val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+ extendPrefixes(minCount, pairsForDistributed)
+ pairsForDistributed.unpersist()
+ val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+ pairsForDistributed = largerPairsPart
+ pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+ pairsForLocal ++= smallerPairsPart
+ resultsAccumulator ++= nextPatternAndCounts.collect()
+ }
+
+ // Process the small projected databases locally
+ val remainingResults = getPatternsInLocal(
+ minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
+
+ (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+ .map { case (pattern, count) => (pattern.toArray, count) }
}
+
/**
- * Get the minimum count (sequences count * minSupport).
- * @param sequences input data set, contains a set of sequences,
- * @return minimum count,
+ * Partitions the prefix-suffix pairs by projected database size.
+ * @param prefixSuffixPairs prefix (length n) and suffix pairs,
+ * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+ * greater than [[maxLocalProjDBSize]]
*/
- private def getMinCount(sequences: RDD[Array[Int]]): Long = {
- if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+ private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+ : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+ val prefixToSuffixSize = prefixSuffixPairs
+ .aggregateByKey(0)(
+ seqOp = { case (count, suffix) => count + suffix.length },
+ combOp = { _ + _ })
+ val smallPrefixes = prefixToSuffixSize
+ .filter(_._2 <= maxLocalProjDBSize)
+ .keys
+ .collect()
+ .toSet
+ val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+ val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
+ (small.collect(), large)
}
/**
- * Generates frequent items by filtering the input data using minimal count level.
- * @param minCount the absolute minimum count
- * @param sequences original sequences data
- * @return array of item and count pair
+ * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
+ * and remaining work.
+ * @param minCount minimum count
+ * @param prefixSuffixPairs prefix (length N) and suffix pairs,
+ * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
+ * prefix, corresponding suffix) pairs.
*/
- private def getFreqItemAndCounts(
+ private def extendPrefixes(
minCount: Long,
- sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
- sequences.flatMap(_.distinct.map((_, 1L)))
+ prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+ : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+ // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+ // Every (prefix :+ item) is guaranteed to have support exceeding `minSupport`
+ val prefixItemPairAndCounts = prefixSuffixPairs
+ .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
- }
- /**
- * Get the frequent prefixes' projected database.
- * @param frequentPrefixes frequent prefixes
- * @param sequences sequences data
- * @return prefixes and projected database
- */
- private def getPrefixAndProjectedDatabase(
- frequentPrefixes: Array[Int],
- sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
- val filteredSequences = sequences.map { p =>
- p.filter (frequentPrefixes.contains(_) )
- }
- filteredSequences.flatMap { x =>
- frequentPrefixes.map { y =>
- val sub = LocalPrefixSpan.getSuffix(y, x)
- (Array(y), sub)
- }.filter(_._2.nonEmpty)
- }
+ // Map from prefix to set of possible next items from suffix
+ val prefixToNextItems = prefixItemPairAndCounts
+ .keys
+ .groupByKey()
+ .mapValues(_.toSet)
+ .collect()
+ .toMap
+
+ // Frequent patterns with length N+1 and their corresponding counts
+ val extendedPrefixAndCounts = prefixItemPairAndCounts
+ .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+ // Remaining work, all prefixes will have length N+1
+ val extendedPrefixAndSuffix = prefixSuffixPairs
+ .filter(x => prefixToNextItems.contains(x._1))
+ .flatMap { case (prefix, suffix) =>
+ val frequentNextItems = prefixToNextItems(prefix)
+ val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+ frequentNextItems.flatMap { item =>
+ LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+ case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
+ case _ => None
+ }
+ }
+ }
+
+ (extendedPrefixAndCounts, extendedPrefixAndSuffix)
}
/**
- * calculate the patterns in local.
+ * Calculate the patterns locally.
* @param minCount the absolute minimum count
- * @param data patterns and projected sequences data data
+ * @param data prefixes and projected sequences data
* @return patterns
*/
private def getPatternsInLocal(
minCount: Long,
- data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
- data.flatMap { case (prefix, projDB) =>
- LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
- .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+ data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+ data.flatMap {
+ case (prefix, projDB) =>
+ LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+ .map { case (pattern: List[Int], count: Long) =>
+ (pattern.reverse, count)
+ }
}
}
}
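For context, a hedged end-to-end sketch of driving the reworked algorithm; the toy sequences and the existing SparkContext `sc` are assumptions, and caching the input follows the `logWarning` hint in `run()`:

```scala
import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.storage.StorageLevel

val sequences = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1)
)).persist(StorageLevel.MEMORY_ONLY)

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)       // a pattern must occur in at least half of the sequences
  .setMaxPatternLength(5)

prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ",", "]") + " -> " + count)
}
```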
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
new file mode 100644
index 0000000000000..72d3aabc9b1f4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
+ * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to
+ * the distributed data type (RDD, Graph, etc.).
+ *
+ * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
+ * as well as unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created,
+ * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are
+ * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ * - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ * - If using checkpointing and the checkpoint interval has been reached,
+ * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ * - Remove older checkpoints.
+ *
+ * WARNINGS:
+ * - This class should NOT be copied (since copies may conflict on which Datasets should be
+ * checkpointed).
+ * - This class removes checkpoint files once later Datasets have been checkpointed.
+ * However, references to the older Datasets will still return isCheckpointed = true.
+ *
+ * @param checkpointInterval Datasets will be checkpointed at this interval
+ * @param sc SparkContext for the Datasets given to this checkpointer
+ * @tparam T Dataset type, such as RDD[Double]
+ */
+private[mllib] abstract class PeriodicCheckpointer[T](
+ val checkpointInterval: Int,
+ val sc: SparkContext) extends Logging {
+
+ /** FIFO queue of past checkpointed Datasets */
+ private val checkpointQueue = mutable.Queue[T]()
+
+ /** FIFO queue of past persisted Datasets */
+ private val persistedQueue = mutable.Queue[T]()
+
+ /** Number of times [[update()]] has been called */
+ private var updateCount = 0
+
+ /**
+ * Update with a new Dataset. Handle persistence and checkpointing as needed.
+ * Since this handles persistence and checkpointing, this should be called before the Dataset
+ * has been materialized.
+ *
+ * @param newData New Dataset created from previous Datasets in the lineage.
+ */
+ def update(newData: T): Unit = {
+ persist(newData)
+ persistedQueue.enqueue(newData)
+ // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class:
+ // Users should call [[update()]] when a new Dataset has been created,
+ // before the Dataset has been materialized.
+ while (persistedQueue.size > 3) {
+ val dataToUnpersist = persistedQueue.dequeue()
+ unpersist(dataToUnpersist)
+ }
+ updateCount += 1
+
+ // Handle checkpointing (after persisting)
+ if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+ // Add new checkpoint before removing old checkpoints.
+ checkpoint(newData)
+ checkpointQueue.enqueue(newData)
+ // Remove checkpoints before the latest one.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // Delete the oldest checkpoint only if the next checkpoint exists.
+ if (isCheckpointed(checkpointQueue.head)) {
+ removeCheckpointFile()
+ } else {
+ canDelete = false
+ }
+ }
+ }
+ }
+
+ /** Checkpoint the Dataset */
+ protected def checkpoint(data: T): Unit
+
+ /** Return true iff the Dataset is checkpointed */
+ protected def isCheckpointed(data: T): Boolean
+
+ /**
+ * Persist the Dataset.
+ * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
+ */
+ protected def persist(data: T): Unit
+
+ /** Unpersist the Dataset */
+ protected def unpersist(data: T): Unit
+
+ /** Get list of checkpoint files for this given Dataset */
+ protected def getCheckpointFiles(data: T): Iterable[String]
+
+ /**
+ * Call this at the end to delete any remaining checkpoint files.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.nonEmpty) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
+ * This prints a warning but does not fail if the files cannot be removed.
+ */
+ private def removeCheckpointFile(): Unit = {
+ val old = checkpointQueue.dequeue()
+ // Since the old checkpoint is not deleted by Spark, we manually delete it.
+ val fs = FileSystem.get(sc.hadoopConfiguration)
+ getCheckpointFiles(old).foreach { checkpointFile =>
+ try {
+ fs.delete(new Path(checkpointFile), true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+ checkpointFile)
+ }
+ }
+ }
+
+}
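The abstract hooks above are all a concrete checkpointer needs to supply. The condensed sketch below shows them filled in for plain RDDs; the `PeriodicRDDCheckpointer` added in the next file is the real implementation, so the class name here is hypothetical and shown only to make the contract concrete.

```scala
package org.apache.spark.mllib.impl

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Hypothetical condensed subclass; see PeriodicRDDCheckpointer below for the real one.
private[mllib] class SketchRDDCheckpointer[T](
    checkpointInterval: Int,
    sc: SparkContext)
  extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {

  override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()

  override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed

  override protected def persist(data: RDD[T]): Unit = {
    // Only persist if the caller has not already chosen a storage level.
    if (data.getStorageLevel == StorageLevel.NONE) {
      data.persist()
    }
  }

  override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

  override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
    data.getCheckpointFile.toList
  }
}
```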
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 6e5dd119dd653..11a059536c50c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -17,11 +17,7 @@
package org.apache.spark.mllib.impl
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{Path, FileSystem}
-
-import org.apache.spark.Logging
+import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
@@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel
* Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
* unpersisting and removing checkpoint files.
*
- * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
+ * Users should call update() when a new graph has been created,
* before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are
* responsible for materializing the graph to ensure that persisting and checkpointing actually
* occur.
*
- * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
+ * When update() is called, this does the following:
* - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
* - Unpersist graphs from queue until there are at most 3 persisted graphs.
* - If using checkpointing and the checkpoint interval has been reached,
@@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
* Example usage:
* {{{
* val (graph1, graph2, graph3, ...) = ...
- * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
+ * val cp = new PeriodicGraphCheckpointer(2, sc)
* graph1.vertices.count(); graph1.edges.count()
* // persisted: graph1
* cp.updateGraph(graph2)
@@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel
* // checkpointed: graph4
* }}}
*
- * @param currentGraph Initial graph
* @param checkpointInterval Graphs will be checkpointed at this interval
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
- * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
+ * TODO: Move this out of MLlib?
*/
private[mllib] class PeriodicGraphCheckpointer[VD, ED](
- var currentGraph: Graph[VD, ED],
- val checkpointInterval: Int) extends Logging {
-
- /** FIFO queue of past checkpointed RDDs */
- private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
-
- /** FIFO queue of past persisted RDDs */
- private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
-
- /** Number of times [[updateGraph()]] has been called */
- private var updateCount = 0
-
- /**
- * Spark Context for the Graphs given to this checkpointer.
- * NOTE: This code assumes that only one SparkContext is used for the given graphs.
- */
- private val sc = currentGraph.vertices.sparkContext
+ checkpointInterval: Int,
+ sc: SparkContext)
+ extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
- updateGraph(currentGraph)
+ override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
- /**
- * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
- * Since this handles persistence and checkpointing, this should be called before the graph
- * has been materialized.
- *
- * @param newGraph New graph created from previous graphs in the lineage.
- */
- def updateGraph(newGraph: Graph[VD, ED]): Unit = {
- if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
- newGraph.persist()
- }
- persistedQueue.enqueue(newGraph)
- // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
- // Users should call [[updateGraph()]] when a new graph has been created,
- // before the graph has been materialized.
- while (persistedQueue.size > 3) {
- val graphToUnpersist = persistedQueue.dequeue()
- graphToUnpersist.unpersist(blocking = false)
- }
- updateCount += 1
+ override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
- // Handle checkpointing (after persisting)
- if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
- // Add new checkpoint before removing old checkpoints.
- newGraph.checkpoint()
- checkpointQueue.enqueue(newGraph)
- // Remove checkpoints before the latest one.
- var canDelete = true
- while (checkpointQueue.size > 1 && canDelete) {
- // Delete the oldest checkpoint only if the next checkpoint exists.
- if (checkpointQueue.get(1).get.isCheckpointed) {
- removeCheckpointFile()
- } else {
- canDelete = false
- }
- }
+ override protected def persist(data: Graph[VD, ED]): Unit = {
+ if (data.vertices.getStorageLevel == StorageLevel.NONE) {
+ data.persist()
}
}
- /**
- * Call this at the end to delete any remaining checkpoint files.
- */
- def deleteAllCheckpoints(): Unit = {
- while (checkpointQueue.size > 0) {
- removeCheckpointFile()
- }
- }
+ override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
- /**
- * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
- * This prints a warning but does not fail if the files cannot be removed.
- */
- private def removeCheckpointFile(): Unit = {
- val old = checkpointQueue.dequeue()
- // Since the old checkpoint is not deleted by Spark, we manually delete it.
- val fs = FileSystem.get(sc.hadoopConfiguration)
- old.getCheckpointFiles.foreach { checkpointFile =>
- try {
- fs.delete(new Path(checkpointFile), true)
- } catch {
- case e: Exception =>
- logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
- checkpointFile)
- }
- }
+ override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = {
+ data.getCheckpointFiles
}
-
}
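For reference, a minimal sketch of the new update() call pattern for the Graph-based checkpointer; graph1, graph2, and the checkpoint directory are assumed to exist, and since the class is private[mllib] this only compiles from MLlib internals.

  // Assumes sc.setCheckpointDir(...) has been called and graph2 derives from graph1.
  val checkpointer = new PeriodicGraphCheckpointer[Double, Double](2, sc)
  checkpointer.update(graph1)
  graph1.vertices.count(); graph1.edges.count()  // materialize so persist/checkpoint take effect
  checkpointer.update(graph2)
  graph2.vertices.count(); graph2.edges.count()
  checkpointer.deleteAllCheckpoints()            // delete any remaining checkpoint files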
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
new file mode 100644
index 0000000000000..f31ed2aa90a64
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing RDDs.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new RDD has been created,
+ * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are
+ * responsible for materializing the RDD to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs.
+ * - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
+ * - If using checkpointing and the checkpoint interval has been reached,
+ * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
+ * - Remove older checkpoints.
+ *
+ * WARNINGS:
+ * - This class should NOT be copied (since copies may conflict on which RDDs should be
+ * checkpointed).
+ * - This class removes checkpoint files once later RDDs have been checkpointed.
+ * However, references to the older RDDs will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ * val (rdd1, rdd2, rdd3, ...) = ...
+ * val cp = new PeriodicRDDCheckpointer(2, sc)
+ * rdd1.count();
+ * // persisted: rdd1
+ * cp.update(rdd2)
+ * rdd2.count();
+ * // persisted: rdd1, rdd2
+ * // checkpointed: rdd2
+ * cp.update(rdd3)
+ * rdd3.count();
+ * // persisted: rdd1, rdd2, rdd3
+ * // checkpointed: rdd2
+ * cp.update(rdd4)
+ * rdd4.count();
+ * // persisted: rdd2, rdd3, rdd4
+ * // checkpointed: rdd4
+ * cp.update(rdd5)
+ * rdd5.count();
+ * // persisted: rdd3, rdd4, rdd5
+ * // checkpointed: rdd4
+ * }}}
+ *
+ * @param checkpointInterval RDDs will be checkpointed at this interval
+ * @tparam T RDD element type
+ *
+ * TODO: Move this out of MLlib?
+ */
+private[mllib] class PeriodicRDDCheckpointer[T](
+ checkpointInterval: Int,
+ sc: SparkContext)
+ extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
+
+ override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()
+
+ override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
+
+ override protected def persist(data: RDD[T]): Unit = {
+ if (data.getStorageLevel == StorageLevel.NONE) {
+ data.persist()
+ }
+ }
+
+ override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
+
+ override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
+ data.getCheckpointFile.map(x => x)
+ }
+}
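The RDD variant follows the same pattern; a minimal sketch, assuming a checkpoint directory is configured and using made-up sample data (again private[mllib], so only usable from MLlib internals).

  sc.setCheckpointDir("/tmp/checkpoints")          // assumed setup
  val checkpointer = new PeriodicRDDCheckpointer[Double](2, sc)
  var rdd = sc.parallelize(1 to 1000).map(_.toDouble)
  checkpointer.update(rdd)   // register the new RDD before it is materialized
  rdd.count()                // materializing triggers the persist/checkpoint
  rdd = rdd.map(_ * 2.0)     // next RDD in the lineage
  checkpointer.update(rdd)
  rdd.count()
  checkpointer.deleteAllCheckpoints()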
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index d82ba2456df1a..88914fa875990 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setByte(0, 0)
row.setInt(1, sm.numRows)
row.setInt(2, sm.numCols)
- row.update(3, sm.colPtrs.toSeq)
- row.update(4, sm.rowIndices.toSeq)
- row.update(5, sm.values.toSeq)
+ row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
+ row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
+ row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, sm.isTransposed)
case dm: DenseMatrix =>
@@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setInt(2, dm.numCols)
row.setNullAt(3)
row.setNullAt(4)
- row.update(5, dm.values.toSeq)
+ row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, dm.isTransposed)
}
row
@@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
val tpe = row.getByte(0)
val numRows = row.getInt(1)
val numCols = row.getInt(2)
- val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray
+ val values = row.getArray(5).toArray.map(_.asInstanceOf[Double])
val isTransposed = row.getBoolean(6)
tpe match {
case 0 =>
- val colPtrs =
- row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray
- val rowIndices =
- row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray
+ val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int])
+ val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int])
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
case 1 =>
new DenseMatrix(numRows, numCols, values, isTransposed)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
index 9669c364bad8f..b416d50a5631e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
@@ -25,3 +25,11 @@ import org.apache.spark.annotation.Experimental
*/
@Experimental
case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType)
+
+/**
+ * :: Experimental ::
+ * Represents QR factors.
+ */
+@Experimental
+case class QRDecomposition[UType, VType](Q: UType, R: VType)
+
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 0cb28d78bec05..89a1818db0d1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
val row = new GenericMutableRow(4)
row.setByte(0, 0)
row.setInt(1, size)
- row.update(2, indices.toSeq)
- row.update(3, values.toSeq)
+ row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
+ row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
case DenseVector(values) =>
val row = new GenericMutableRow(4)
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
- row.update(3, values.toSeq)
+ row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
}
}
@@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
tpe match {
case 0 =>
val size = row.getInt(1)
- val indices =
- row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray
- val values =
- row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
+ val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int])
+ val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new SparseVector(size, indices, values)
case 1 =>
- val values =
- row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
+ val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new DenseVector(values)
}
}
@@ -637,6 +634,8 @@ class SparseVector(
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
s" indices match the dimension of the values. You provided ${indices.length} indices and " +
s" ${values.length} values.")
+ require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
+ s"which exceeds the specified vector size ${size}.")
override def toString: String =
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 1626da9c3d2ee..bfc90c9ef8527 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -22,7 +22,7 @@ import java.util.Arrays
import scala.collection.mutable.ListBuffer
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy,
- svd => brzSvd}
+ svd => brzSvd, MatrixSingularException, inv}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -497,6 +497,50 @@ class RowMatrix(
columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
}
+ /**
+ * Compute the QR decomposition of a [[RowMatrix]]. The implementation is designed to optimize
+ * the QR decomposition (factorization) for a [[RowMatrix]] with a tall and skinny shape.
+ * Reference:
+ * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce
+ * architectures" ([[http://dx.doi.org/10.1145/1996092.1996103]])
+ *
+ * @param computeQ whether to compute the Q factor
+ * @return QRDecomposition(Q, R); Q is null if computeQ is false.
+ */
+ def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = {
+ val col = numCols().toInt
+ // split rows horizontally into smaller matrices, and compute QR for each of them
+ val blockQRs = rows.glom().map { partRows =>
+ val bdm = BDM.zeros[Double](partRows.length, col)
+ var i = 0
+ partRows.foreach { row =>
+ bdm(i, ::) := row.toBreeze.t
+ i += 1
+ }
+ breeze.linalg.qr.reduced(bdm).r
+ }
+
+ // combine the R part from previous results vertically into a tall matrix
+ val combinedR = blockQRs.treeReduce{ (r1, r2) =>
+ val stackedR = BDM.vertcat(r1, r2)
+ breeze.linalg.qr.reduced(stackedR).r
+ }
+ val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix)
+ val finalQ = if (computeQ) {
+ try {
+ val invR = inv(combinedR)
+ this.multiply(Matrices.fromBreeze(invR))
+ } catch {
+ case err: MatrixSingularException =>
+ logWarning("R is not invertible and return Q as null")
+ null
+ }
+ } else {
+ null
+ }
+ QRDecomposition(finalQ, finalR)
+ }
+
/**
* Find all similar columns using the DIMSUM sampling algorithm, described in two papers
*
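A short usage sketch for the new method; the sample matrix and the existing SparkContext sc are assumptions, not part of this patch.

  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.linalg.distributed.RowMatrix

  val rows = sc.parallelize(Seq(
    Vectors.dense(1.0, 2.0),
    Vectors.dense(3.0, 4.0),
    Vectors.dense(5.0, 6.0)))
  val mat = new RowMatrix(rows)

  val qr = mat.tallSkinnyQR(computeQ = true)  // QRDecomposition[RowMatrix, Matrix]
  val q: RowMatrix = qr.Q                     // distributed factor with orthonormal columns
  val r = qr.R                                // small local upper-triangular factor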
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 93290e6508529..56c549ef99cb7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -26,6 +26,7 @@ import org.apache.spark.storage.StorageLevel
/**
* A more compact class to represent a rating than Tuple3[Int, Int, Double].
+ * @since 0.8.0
*/
case class Rating(user: Int, product: Int, rating: Double)
@@ -254,6 +255,7 @@ class ALS private (
/**
* Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
+ * @since 0.8.0
*/
object ALS {
/**
@@ -269,6 +271,7 @@ object ALS {
* @param lambda regularization factor (recommended: 0.01)
* @param blocks level of parallelism to split computation into
* @param seed random seed
+ * @since 0.9.1
*/
def train(
ratings: RDD[Rating],
@@ -293,6 +296,7 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
* @param blocks level of parallelism to split computation into
+ * @since 0.8.0
*/
def train(
ratings: RDD[Rating],
@@ -315,6 +319,7 @@ object ALS {
* @param rank number of features to use
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
+ * @since 0.8.0
*/
def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
: MatrixFactorizationModel = {
@@ -331,6 +336,7 @@ object ALS {
* @param ratings RDD of (userID, productID, rating) pairs
* @param rank number of features to use
* @param iterations number of iterations of ALS (recommended: 10-20)
+ * @since 0.8.0
*/
def train(ratings: RDD[Rating], rank: Int, iterations: Int)
: MatrixFactorizationModel = {
@@ -351,6 +357,7 @@ object ALS {
* @param blocks level of parallelism to split computation into
* @param alpha confidence parameter
* @param seed random seed
+ * @since 0.8.1
*/
def trainImplicit(
ratings: RDD[Rating],
@@ -377,6 +384,7 @@ object ALS {
* @param lambda regularization factor (recommended: 0.01)
* @param blocks level of parallelism to split computation into
* @param alpha confidence parameter
+ * @since 0.8.1
*/
def trainImplicit(
ratings: RDD[Rating],
@@ -401,6 +409,7 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
* @param alpha confidence parameter
+ * @since 0.8.1
*/
def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
: MatrixFactorizationModel = {
@@ -418,6 +427,7 @@ object ALS {
* @param ratings RDD of (userID, productID, rating) pairs
* @param rank number of features to use
* @param iterations number of iterations of ALS (recommended: 10-20)
+ * @since 0.8.1
*/
def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
: MatrixFactorizationModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 43d219a49cf4e..261ca9cef0c5b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel
* the features computed for this user.
* @param productFeatures RDD of tuples where each tuple represents the productId
* and the features computed for this product.
+ * @since 0.8.0
*/
class MatrixFactorizationModel(
val rank: Int,
@@ -73,7 +74,9 @@ class MatrixFactorizationModel(
}
}
- /** Predict the rating of one user for one product. */
+ /** Predict the rating of one user for one product.
+ * @since 0.8.0
+ */
def predict(user: Int, product: Int): Double = {
val userVector = userFeatures.lookup(user).head
val productVector = productFeatures.lookup(product).head
@@ -111,6 +114,7 @@ class MatrixFactorizationModel(
*
* @param usersProducts RDD of (user, product) pairs.
* @return RDD of Ratings.
+ * @since 0.9.0
*/
def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
// Previously the partitions of ratings are only based on the given products.
@@ -142,6 +146,7 @@ class MatrixFactorizationModel(
/**
* Java-friendly version of [[MatrixFactorizationModel.predict]].
+ * @since 1.2.0
*/
def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
@@ -157,6 +162,7 @@ class MatrixFactorizationModel(
* by score, decreasing. The first returned is the one predicted to be most strongly
* recommended to the user. The score is an opaque value that indicates how strongly
* recommended the product is.
+ * @since 1.1.0
*/
def recommendProducts(user: Int, num: Int): Array[Rating] =
MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num)
@@ -173,6 +179,7 @@ class MatrixFactorizationModel(
* by score, decreasing. The first returned is the one predicted to be most strongly
* recommended to the product. The score is an opaque value that indicates how strongly
* recommended the user is.
+ * @since 1.1.0
*/
def recommendUsers(product: Int, num: Int): Array[Rating] =
MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num)
@@ -180,6 +187,20 @@ class MatrixFactorizationModel(
protected override val formatVersion: String = "1.0"
+ /**
+ * Save this model to the given path.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using [[Loader.load]].
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * If the directory already exists, this method throws an exception.
+ * @since 1.3.0
+ */
override def save(sc: SparkContext, path: String): Unit = {
MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
}
@@ -191,6 +212,7 @@ class MatrixFactorizationModel(
* @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of
* rating objects which contains the same userId, recommended productID and a "score" in the
* rating field. Semantics of score is same as recommendProducts API
+ * @since 1.4.0
*/
def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = {
MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map {
@@ -208,6 +230,7 @@ class MatrixFactorizationModel(
* @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array
* of rating objects which contains the recommended userId, same productID and a "score" in the
* rating field. Semantics of score is same as recommendUsers API
+ * @since 1.4.0
*/
def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = {
MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map {
@@ -218,6 +241,9 @@ class MatrixFactorizationModel(
}
}
+/**
+ * @since 1.3.0
+ */
object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
import org.apache.spark.mllib.util.Loader._
@@ -292,6 +318,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
}
}
+ /**
+ * Load a model from the given path.
+ *
+ * The model should have been saved by [[Saveable.save]].
+ *
+ * @param sc Spark context used for loading model files.
+ * @param path Path specifying the directory to which the model was saved.
+ * @return Model instance
+ * @since 1.3.0
+ */
override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
index 58a50f9c19f14..93a6753efd4d9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
@@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD
* .setBandwidth(3.0)
* val densities = kd.estimate(Array(-1.0, 2.0, 5.0))
* }}}
+ * @since 1.4.0
*/
@Experimental
class KernelDensity extends Serializable {
@@ -51,6 +52,7 @@ class KernelDensity extends Serializable {
/**
* Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`).
+ * @since 1.4.0
*/
def setBandwidth(bandwidth: Double): this.type = {
require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.")
@@ -60,6 +62,7 @@ class KernelDensity extends Serializable {
/**
* Sets the sample to use for density estimation.
+ * @since 1.4.0
*/
def setSample(sample: RDD[Double]): this.type = {
this.sample = sample
@@ -68,6 +71,7 @@ class KernelDensity extends Serializable {
/**
* Sets the sample to use for density estimation (for Java users).
+ * @since 1.4.0
*/
def setSample(sample: JavaRDD[java.lang.Double]): this.type = {
this.sample = sample.rdd.asInstanceOf[RDD[Double]]
@@ -76,6 +80,7 @@ class KernelDensity extends Serializable {
/**
* Estimates probability density function at the given array of points.
+ * @since 1.4.0
*/
def estimate(points: Array[Double]): Array[Double] = {
val sample = this.sample
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index d321cc554c1cc..62da9f2ef22a3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
* Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
* Zero elements (including explicit zero values) are skipped when calling add(),
* to have time complexity O(nnz) instead of O(n) for each column.
+ * @since 1.1.0
*/
@DeveloperApi
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
@@ -52,6 +53,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*
* @param sample The sample in dense/sparse vector format to be added into this summarizer.
* @return This MultivariateOnlineSummarizer object.
+ * @since 1.1.0
*/
def add(sample: Vector): this.type = {
if (n == 0) {
@@ -107,6 +109,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*
* @param other The other MultivariateOnlineSummarizer to be merged.
* @return This MultivariateOnlineSummarizer object.
+ * @since 1.1.0
*/
def merge(other: MultivariateOnlineSummarizer): this.type = {
if (this.totalCnt != 0 && other.totalCnt != 0) {
@@ -149,6 +152,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
this
}
+ /**
+ * @since 1.1.0
+ */
override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
@@ -161,6 +167,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
Vectors.dense(realMean)
}
+ /**
+ * @since 1.1.0
+ */
override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
@@ -183,14 +192,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
Vectors.dense(realVariance)
}
+ /**
+ * @since 1.1.0
+ */
override def count: Long = totalCnt
+ /**
+ * @since 1.1.0
+ */
override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
Vectors.dense(nnz)
}
+ /**
+ * @since 1.1.0
+ */
override def max: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
@@ -202,6 +220,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
Vectors.dense(currMax)
}
+ /**
+ * @since 1.1.0
+ */
override def min: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
@@ -213,6 +234,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
Vectors.dense(currMin)
}
+ /**
+ * @since 1.2.0
+ */
override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
@@ -227,6 +251,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
Vectors.dense(realMagnitude)
}
+ /**
+ * @since 1.2.0
+ */
override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
index 6a364c93284af..3bb49f12289e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
@@ -21,46 +21,55 @@ import org.apache.spark.mllib.linalg.Vector
/**
* Trait for multivariate statistical summary of a data matrix.
+ * @since 1.0.0
*/
trait MultivariateStatisticalSummary {
/**
* Sample mean vector.
+ * @since 1.0.0
*/
def mean: Vector
/**
* Sample variance vector. Should return a zero vector if the sample size is 1.
+ * @since 1.0.0
*/
def variance: Vector
/**
* Sample size.
+ * @since 1.0.0
*/
def count: Long
/**
* Number of nonzero elements (including explicitly presented zero values) in each column.
+ * @since 1.0.0
*/
def numNonzeros: Vector
/**
* Maximum value of each column.
+ * @since 1.0.0
*/
def max: Vector
/**
* Minimum value of each column.
+ * @since 1.0.0
*/
def min: Vector
/**
* Euclidean magnitude of each column
+ * @since 1.2.0
*/
def normL2: Vector
/**
* L1 norm of each column
+ * @since 1.2.0
*/
def normL1: Vector
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index 90332028cfb3a..f84502919e381 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD
/**
* :: Experimental ::
* API for statistical functions in MLlib.
+ * @since 1.1.0
*/
@Experimental
object Statistics {
@@ -41,6 +42,7 @@ object Statistics {
*
* @param X an RDD[Vector] for which column-wise summary statistics are to be computed.
* @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics.
+ * @since 1.1.0
*/
def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = {
new RowMatrix(X).computeColumnSummaryStatistics()
@@ -52,6 +54,7 @@ object Statistics {
*
* @param X an RDD[Vector] for which the correlation matrix is to be computed.
* @return Pearson correlation matrix comparing columns in X.
+ * @since 1.1.0
*/
def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X)
@@ -68,6 +71,7 @@ object Statistics {
* @param method String specifying the method to use for computing correlation.
* Supported: `pearson` (default), `spearman`
* @return Correlation matrix comparing columns in X.
+ * @since 1.1.0
*/
def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method)
@@ -81,10 +85,14 @@ object Statistics {
* @param x RDD[Double] of the same cardinality as y.
* @param y RDD[Double] of the same cardinality as x.
* @return A Double containing the Pearson correlation between the two input RDD[Double]s
+ * @since 1.1.0
*/
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
- /** Java-friendly version of [[corr()]] */
+ /**
+ * Java-friendly version of [[corr()]]
+ * @since 1.4.1
+ */
def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
@@ -101,10 +109,14 @@ object Statistics {
* Supported: `pearson` (default), `spearman`
* @return A Double containing the correlation between the two input RDD[Double]s using the
* specified method.
+ * @since 1.1.0
*/
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
- /** Java-friendly version of [[corr()]] */
+ /**
+ * Java-friendly version of [[corr()]]
+ * @since 1.4.1
+ */
def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
@@ -121,6 +133,7 @@ object Statistics {
* `expected` is rescaled if the `expected` sum differs from the `observed` sum.
* @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
* the method used, and the null hypothesis.
+ * @since 1.1.0
*/
def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = {
ChiSqTest.chiSquared(observed, expected)
@@ -135,6 +148,7 @@ object Statistics {
* @param observed Vector containing the observed categorical counts/relative frequencies.
* @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
* the method used, and the null hypothesis.
+ * @since 1.1.0
*/
def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed)
@@ -145,6 +159,7 @@ object Statistics {
* @param observed The contingency matrix (containing either counts or relative frequencies).
* @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
* the method used, and the null hypothesis.
+ * @since 1.1.0
*/
def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed)
@@ -157,6 +172,7 @@ object Statistics {
* Real-valued features will be treated as categorical for each distinct value.
* @return an array containing the ChiSquaredTestResult for every feature against the label.
* The order of the elements in the returned array reflects the order of input features.
+ * @since 1.1.0
*/
def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
ChiSqTest.chiSquaredFeatures(data)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index cf51b24ff777f..9aa7763d7890d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils
*
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
+ * @since 1.3.0
*/
@DeveloperApi
class MultivariateGaussian (
@@ -60,12 +61,16 @@ class MultivariateGaussian (
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
- /** Returns density of this multivariate Gaussian at given point, x */
+ /** Returns density of this multivariate Gaussian at given point, x
+ * @since 1.3.0
+ */
def pdf(x: Vector): Double = {
pdf(x.toBreeze)
}
- /** Returns the log-density of this multivariate Gaussian at given point, x */
+ /** Returns the log-density of this multivariate Gaussian at given point, x
+ * @since 1.3.0
+ */
def logpdf(x: Vector): Double = {
logpdf(x.toBreeze)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index a835f96d5d0e3..9ce6faa137c41 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging {
false
}
+ // Prepare periodic checkpointers
+ val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+ val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+
timer.stop("init")
logDebug("##########")
logDebug("Building tree 0")
logDebug("##########")
- var data = input
// Initialize tree
timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+ val firstTreeModel = new DecisionTree(treeStrategy).run(input)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+ predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean())
// Note: A model of type regression is used since we require raw prediction
@@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging {
var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+ if (validate) validatePredErrorCheckpointer.update(validatePredError)
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
var bestM = 1
- // pseudo-residual for second iteration
- data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
-
var m = 1
- while (m < numIterations) {
+ var doneLearning = false
+ while (m < numIterations && !doneLearning) {
+ // Update data with pseudo-residuals
+ val data = predError.zip(input).map { case ((pred, _), point) =>
+ LabeledPoint(-loss.gradient(pred, point.label), point.features)
+ }
+
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m")
- // Create partial model
+ // Update partial model
baseLearners(m) = model
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
- // Note: A model of type regression is used since we require raw prediction
- val partialModel = new GradientBoostedTreesModel(
- Regression, baseLearners.slice(0, m + 1),
- baseLearnerWeights.slice(0, m + 1))
predError = GradientBoostedTreesModel.updatePredictionError(
input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+ predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean())
if (validate) {
@@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging {
validatePredError = GradientBoostedTreesModel.updatePredictionError(
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+ validatePredErrorCheckpointer.update(validatePredError)
val currentValidateError = validatePredError.values.mean()
if (bestValidateError - currentValidateError < validationTol) {
- return new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo,
- baseLearners.slice(0, bestM),
- baseLearnerWeights.slice(0, bestM))
+ doneLearning = true
} else if (currentValidateError < bestValidateError) {
- bestValidateError = currentValidateError
- bestM = m + 1
+ bestValidateError = currentValidateError
+ bestM = m + 1
}
}
- // Update data with pseudo-residuals
- data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
m += 1
}
@@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
+ predErrorCheckpointer.deleteAllCheckpoints()
+ validatePredErrorCheckpointer.deleteAllCheckpoints()
if (persistedInput) input.unpersist()
if (validate) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 2d6b01524ff3d..9fd30c9b56319 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* learning rate should be between in the interval (0, 1]
* @param validationTol Useful when runWithValidation is used. If the error rate on the
* validation input between two iterations is less than the validationTol
- * then stop. Ignored when [[run]] is used.
+ * then stop. Ignored when
+ * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
*/
@Experimental
case class BoostingStrategy(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 380291ac22bd3..9fe264656ede7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -128,9 +128,13 @@ private[spark] object DecisionTreeMetadata extends Logging {
// based on the number of training examples.
if (strategy.categoricalFeaturesInfo.nonEmpty) {
val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ val maxCategory =
+ strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
require(maxCategoriesPerFeature <= maxPossibleBins,
- s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
- s"in categorical features (= $maxCategoriesPerFeature)")
+ s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
+ s"number of values in each categorical feature, but categorical feature $maxCategory " +
+ s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
+ "features with a large number of values, or add more training examples.")
}
val unorderedFeatures = new mutable.HashSet[Int]()
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index b48f190f599a2..d272a42c8576f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -19,6 +19,7 @@
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.Arrays;
import scala.Tuple2;
@@ -59,7 +60,10 @@ public void tearDown() {
@Test
public void localLDAModel() {
- LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
+ Matrix topics = LDASuite$.MODULE$.tinyTopics();
+ double[] topicConcentration = new double[topics.numRows()];
+ Arrays.fill(topicConcentration, 1.0D / topics.numRows());
+ LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D);
// Check: basic parameters
assertEquals(model.k(), tinyK);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 82c345491bb3c..a7bc77965fefd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
/**
@@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val gbt = new GBTClassifier()
+ .setMaxDepth(2)
+ .setLossType("logistic")
+ .setMaxIter(5)
+ .setStepSize(0.1)
+ .setCheckpointInterval(2)
+ val model = gbt.fit(df)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 1b6b69c7dc71e..ab711c8e4b215 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
ParamsSuite.checkParams(model)
}
@@ -167,9 +167,19 @@ private object RandomForestClassifierSuite {
val newModel = rf.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
- oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
+ oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures,
+ numClasses)
TreeTests.checkEqual(oldModelAsNew, newModel)
assert(newModel.hasParent)
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
+ assert(newModel.numClasses == numClasses)
+ val results = newModel.transform(newData)
+ results.select("rawPrediction", "prediction").collect().foreach {
+ case Row(raw: Vector, prediction: Double) => {
+ assert(raw.size == numClasses)
+ val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(predFromRaw == prediction)
+ }
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 65846a846b7b4..321eeb843941c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = encoder.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
}
test("input column without ML attribute") {
@@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = encoder.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index c4b45aee06384..436e66bab09b0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -18,12 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
class RFormulaParserSuite extends SparkFunSuite {
- private def checkParse(formula: String, label: String, terms: Seq[String]) {
- val parsed = RFormulaParser.parse(formula)
- assert(parsed.label == label)
- assert(parsed.terms == terms)
+ private def checkParse(
+ formula: String,
+ label: String,
+ terms: Seq[String],
+ schema: StructType = null) {
+ val resolved = RFormulaParser.parse(formula).resolve(schema)
+ assert(resolved.label == label)
+ assert(resolved.terms == terms)
}
test("parse simple formulas") {
@@ -32,4 +37,46 @@ class RFormulaParserSuite extends SparkFunSuite {
checkParse("y ~ ._foo ", "y", Seq("._foo"))
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
}
+
+ test("parse dot") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ checkParse("a ~ .", "a", Seq("b", "c"), schema)
+ }
+
+ test("parse deletion") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ checkParse("a ~ c - b", "a", Seq("c"), schema)
+ }
+
+ test("parse additions and deletions in order") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ checkParse("a ~ . - b + . - c", "a", Seq("b"), schema)
+ }
+
+ test("dot ignores complex column types") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "tinyint", false)
+ .add("c", "map", true)
+ checkParse("a ~ .", "a", Seq("b"), schema)
+ }
+
+ test("parse intercept") {
+ assert(RFormulaParser.parse("a ~ b").hasIntercept)
+ assert(RFormulaParser.parse("a ~ b + 1").hasIntercept)
+ assert(RFormulaParser.parse("a ~ b - 0").hasIntercept)
+ assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept)
+ assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept)
+ assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept)
+ assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 8148c553e9051..6aed3243afce8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
}
+
+ test("attribute generation") {
+ val formula = new RFormula().setFormula("id ~ a + b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array(
+ new BinaryAttribute(Some("a__bar"), Some(1)),
+ new BinaryAttribute(Some("a__foo"), Some(2)),
+ new NumericAttribute(Some("b"), Some(3))))
+ assert(attrs === expectedAttrs)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 9682edcd9ba84..dbdce0c9dea54 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees =>
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
/**
@@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predictions.min() < -1)
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val df = sqlContext.createDataFrame(data)
+ val gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setMaxIter(5)
+ .setStepSize(0.1)
+ .setCheckpointInterval(2)
+ val model = gbt.fit(df)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
new file mode 100644
index 0000000000000..66e4b170bae80
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
+
+class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+ private val schema = StructType(
+ Array(
+ StructField("label", DoubleType),
+ StructField("features", DoubleType),
+ StructField("weight", DoubleType)))
+
+ private val predictionSchema = StructType(Array(StructField("features", DoubleType)))
+
+ private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
+ val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d))
+ val parallelData = sc.parallelize(data)
+
+ sqlContext.createDataFrame(parallelData, schema)
+ }
+
+ private def generatePredictionInput(features: Seq[Double]): DataFrame = {
+ val data = Seq.tabulate(features.size)(i => Row(features(i)))
+
+ val parallelData = sc.parallelize(data)
+ sqlContext.createDataFrame(parallelData, predictionSchema)
+ }
+
+ test("isotonic regression predictions") {
+ val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18))
+ val trainer = new IsotonicRegression().setIsotonicParam(true)
+
+ val model = trainer.fit(dataset)
+
+ val predictions = model
+ .transform(dataset)
+ .select("prediction").map {
+ case Row(pred) => pred
+ }.collect()
+
+ assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+
+ assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8))
+ assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
+ assert(model.parentModel.isotonic)
+ }
+
+ test("antitonic regression predictions") {
+ val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1))
+ val trainer = new IsotonicRegression().setIsotonicParam(false)
+
+ val model = trainer.fit(dataset)
+ val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0))
+
+ val predictions = model
+ .transform(features)
+ .select("prediction").map {
+ case Row(pred) => pred
+ }.collect()
+
+ assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+ }
+
+ test("params validation") {
+ val dataset = generateIsotonicInput(Seq(1, 2, 3))
+ val ir = new IsotonicRegression
+ ParamsSuite.checkParams(ir)
+ val model = ir.fit(dataset)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("default params") {
+ val dataset = generateIsotonicInput(Seq(1, 2, 3))
+ val ir = new IsotonicRegression()
+ assert(ir.getLabelCol === "label")
+ assert(ir.getFeaturesCol === "features")
+ assert(ir.getWeightCol === "weight")
+ assert(ir.getPredictionCol === "prediction")
+ assert(ir.getIsotonicParam === true)
+
+ val model = ir.fit(dataset)
+ model.transform(dataset)
+ .select("label", "features", "prediction", "weight")
+ .collect()
+
+ assert(model.getLabelCol === "label")
+ assert(model.getFeaturesCol === "features")
+ assert(model.getWeightCol === "weight")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.getIsotonicParam === true)
+ assert(model.hasParent)
+ }
+
+ test("set parameters") {
+ val isotonicRegression = new IsotonicRegression()
+ .setIsotonicParam(false)
+ .setWeightParam("w")
+ .setFeaturesCol("f")
+ .setLabelCol("l")
+ .setPredictionCol("p")
+
+ assert(isotonicRegression.getIsotonicParam === false)
+ assert(isotonicRegression.getWeightCol === "w")
+ assert(isotonicRegression.getFeaturesCol === "f")
+ assert(isotonicRegression.getLabelCol === "l")
+ assert(isotonicRegression.getPredictionCol === "p")
+ }
+
+ test("missing column") {
+ val dataset = generateIsotonicInput(Seq(1, 2, 3))
+
+ intercept[IllegalArgumentException] {
+ new IsotonicRegression().setWeightParam("w").fit(dataset)
+ }
+
+ intercept[IllegalArgumentException] {
+ new IsotonicRegression().setFeaturesCol("f").fit(dataset)
+ }
+
+ intercept[IllegalArgumentException] {
+ new IsotonicRegression().setLabelCol("l").fit(dataset)
+ }
+
+ intercept[IllegalArgumentException] {
+ new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index fd653296c9d97..d7b291d5a6330 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
-import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 30000
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
// Test if we can accurately learn B for Y = logistic(BX) on streaming data
test("parameter accuracy") {
@@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
}
// apply model training to input stream
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
@@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
// apply model training to input stream, storing the intermediate results
// (we add a count to ensure the result is a DStream)
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B)))
inputDStream.count()
@@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
}
// apply model predictions to test stream
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
@@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
}
// train and predict
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
@@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
.setNumIterations(10)
val numBatches = 10
val emptyInput = Seq.empty[Seq[LabeledPoint]]
- val ssc = setupStreams(emptyInput,
+ ssc = setupStreams(emptyInput,
(inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 376a87f0511b4..c43e1e575c09c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM}
+import breeze.linalg.{DenseMatrix => BDM, max, argmax}
import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.Edge
@@ -31,7 +31,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
import LDASuite._
test("LocalLDAModel") {
- val model = new LocalLDAModel(tinyTopics)
+ val model = new LocalLDAModel(tinyTopics,
+ Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
// Check: basic parameters
assert(model.k === tinyK)
@@ -82,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.topicsMatrix === localModel.topicsMatrix)
// Check: topic summaries
- // The odd decimal formatting and sorting is a hack to do a robust comparison.
- val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
- // cut values to 3 digits after the decimal place
- terms.zip(termWeights).map { case (term, weight) =>
- ("%.3f".format(weight).toDouble, term.toInt)
- }
- }.sortBy(_.mkString(""))
- val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
- // cut values to 3 digits after the decimal place
- terms.zip(termWeights).map { case (term, weight) =>
- ("%.3f".format(weight).toDouble, term.toInt)
- }
- }.sortBy(_.mkString(""))
- roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
- assert(t1 === t2)
+ val topicSummary = model.describeTopics().map { case (terms, termWeights) =>
+ Vectors.sparse(tinyVocabSize, terms, termWeights)
+ }.sortBy(_.toString)
+ val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
+ Vectors.sparse(tinyVocabSize, terms, termWeights)
+ }.sortBy(_.toString)
+ topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) =>
+ assert(topics ~== topicsLocal absTol 0.01)
}
// Check: per-doc topic distributions
@@ -196,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
// verify the result, Note this generate the identical result as
// [[https://github.com/Blei-Lab/onlineldavb]]
- val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
- val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
- assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
- assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
+ val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t)
+ val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t)
+ val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950)
+ val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050)
+ assert(topic1 ~== expectedTopic1 absTol 0.01)
+ assert(topic2 ~== expectedTopic2 absTol 0.01)
}
test("OnlineLDAOptimizer with toy data") {
@@ -235,6 +231,114 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("LocalLDAModel logPerplexity") {
+ val k = 2
+ val vocabSize = 6
+ val alpha = 0.01
+ val eta = 0.01
+ val gammaShape = 100
+ // obtained from LDA model trained in gensim, see below
+ val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
+ 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
+ 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
+
+ def toydata: Array[(Long, Vector)] = Array(
+ Vectors.sparse(6, Array(0, 1), Array(1, 1)),
+ Vectors.sparse(6, Array(1, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(0, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 4), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 5), Array(1, 1)),
+ Vectors.sparse(6, Array(4, 5), Array(1, 1))
+ ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+ val docs = sc.parallelize(toydata)
+
+
+ val ldaModel: LocalLDAModel = new LocalLDAModel(
+ topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
+
+ /* Verify results using gensim:
+ import numpy as np
+ from gensim import models
+ corpus = [
+ [(0, 1.0), (1, 1.0)],
+ [(1, 1.0), (2, 1.0)],
+ [(0, 1.0), (2, 1.0)],
+ [(3, 1.0), (4, 1.0)],
+ [(3, 1.0), (5, 1.0)],
+ [(4, 1.0), (5, 1.0)]]
+ np.random.seed(2345)
+ lda = models.ldamodel.LdaModel(
+ corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
+ decay=0.51, offset=1024)
+ print(lda.log_perplexity(corpus))
+ > -3.69051285096
+ */
+
+ assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D)
+ }
+
+ test("LocalLDAModel predict") {
+ val k = 2
+ val vocabSize = 6
+ val alpha = 0.01
+ val eta = 0.01
+ val gammaShape = 100
+ // obtained from LDA model trained in gensim, see below
+ val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
+ 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
+ 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
+
+ def toydata: Array[(Long, Vector)] = Array(
+ Vectors.sparse(6, Array(0, 1), Array(1, 1)),
+ Vectors.sparse(6, Array(1, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(0, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 4), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 5), Array(1, 1)),
+ Vectors.sparse(6, Array(4, 5), Array(1, 1))
+ ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+ val docs = sc.parallelize(toydata)
+
+ val ldaModel: LocalLDAModel = new LocalLDAModel(
+ topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
+
+ /* Verify results using gensim:
+ import numpy as np
+ from gensim import models
+ corpus = [
+ [(0, 1.0), (1, 1.0)],
+ [(1, 1.0), (2, 1.0)],
+ [(0, 1.0), (2, 1.0)],
+ [(3, 1.0), (4, 1.0)],
+ [(3, 1.0), (5, 1.0)],
+ [(4, 1.0), (5, 1.0)]]
+ np.random.seed(2345)
+ lda = models.ldamodel.LdaModel(
+ corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
+ decay=0.51, offset=1024)
+ print(list(lda.get_document_topics(corpus)))
+ > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)],
+ > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)],
+ > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]]
+ */
+
+ val expectedPredictions = List(
+ (0, 0.99504), (0, 0.99504),
+ (0, 0.99504), (1, 0.99504),
+ (1, 0.99504), (1, 0.99504))
+
+ val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) =>
+ // convert results to expectedPredictions format, which only has highest probability topic
+ val topicsBz = topics.toBreeze.toDenseVector
+ (id, (argmax(topicsBz), max(topicsBz)))
+ }.sortByKey()
+ .values
+ .collect()
+
+ expectedPredictions.zip(actualPredictions).foreach { case (expected, actual) =>
+ assert(expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D))
+ }
+ }
+
test("OnlineLDAOptimizer with asymmetric prior") {
def toydata: Array[(Long, Vector)] = Array(
Vectors.sparse(6, Array(0, 1), Array(1, 1)),
@@ -287,7 +391,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
test("model save/load") {
// Test for LocalLDAModel.
- val localModel = new LocalLDAModel(tinyTopics)
+ val localModel = new LocalLDAModel(tinyTopics,
+ Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
val tempDir1 = Utils.createTempDir()
val path1 = tempDir1.toURI.toString
@@ -313,6 +418,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
assert(samelocalModel.k === localModel.k)
assert(samelocalModel.vocabSize === localModel.vocabSize)
+ assert(samelocalModel.docConcentration === localModel.docConcentration)
+ assert(samelocalModel.topicConcentration === localModel.topicConcentration)
+ assert(samelocalModel.gammaShape === localModel.gammaShape)
val sameDistributedModel = DistributedLDAModel.load(sc, path2)
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
@@ -321,6 +429,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
+ assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
val graph = distributedModel.graph
@@ -339,6 +448,46 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("EMLDAOptimizer with empty docs") {
+ val vocabSize = 6
+ val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty))
+ val emptyDocs = emptyDocsArray
+ .zipWithIndex.map { case (wordCounts, docId) =>
+ (docId.toLong, wordCounts)
+ }
+ val distributedEmptyDocs = sc.parallelize(emptyDocs, 2)
+
+ val op = new EMLDAOptimizer()
+ val lda = new LDA()
+ .setK(3)
+ .setMaxIterations(5)
+ .setSeed(12345)
+ .setOptimizer(op)
+
+ val model = lda.run(distributedEmptyDocs)
+ assert(model.vocabSize === vocabSize)
+ }
+
+ test("OnlineLDAOptimizer with empty docs") {
+ val vocabSize = 6
+ val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty))
+ val emptyDocs = emptyDocsArray
+ .zipWithIndex.map { case (wordCounts, docId) =>
+ (docId.toLong, wordCounts)
+ }
+ val distributedEmptyDocs = sc.parallelize(emptyDocs, 2)
+
+ val op = new OnlineLDAOptimizer()
+ val lda = new LDA()
+ .setK(3)
+ .setMaxIterations(5)
+ .setSeed(12345)
+ .setOptimizer(op)
+
+ val model = lda.run(distributedEmptyDocs)
+ assert(model.vocabSize === vocabSize)
+ }
+
}
private[clustering] object LDASuite {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index ac01622b8a089..3645d29dccdb2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.random.XORShiftRandom
@@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
override def maxWaitTimeMillis: Int = 30000
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
test("accuracy for single center and equivalence to grand average") {
// set parameters
val numBatches = 10
@@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
// setup and run the model training
- val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
@@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
// setup and run the model training
- val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
kMeans.trainOn(inputDStream)
inputDStream.count()
})
@@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
// setup and run the model training
- val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
kMeans.trainOn(inputDStream)
inputDStream.count()
})
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 9f107c89f6d80..6dd2dc926acc5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(sequences, 2).cache()
- def compareResult(
- expectedValue: Array[(Array[Int], Long)],
- actualValue: Array[(Array[Int], Long)]): Boolean = {
- expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
- actualValue.map(x => (x._1.toSeq, x._2)).toSet
- }
-
val prefixspan = new PrefixSpan()
.setMinSupport(0.33)
.setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
- assert(compareResult(expectedValue1, result1.collect()))
+ assert(compareResults(expectedValue1, result1.collect()))
prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4), 4L),
(Array(5), 3L)
)
- assert(compareResult(expectedValue2, result2.collect()))
+ assert(compareResults(expectedValue2, result2.collect()))
prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
- assert(compareResult(expectedValue3, result3.collect()))
+ assert(compareResults(expectedValue3, result3.collect()))
+ }
+
+ private def compareResults(
+ expectedValue: Array[(Array[Int], Long)],
+ actualValue: Array[(Array[Int], Long)]): Boolean = {
+ expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+ actualValue.map(x => (x._1.toSeq, x._2)).toSet
}
+
}
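
The compareResults helper moved below the tests converts each pattern to a Seq before comparing sets because Scala arrays compare by reference, so two structurally identical (Array, count) pairs would never be equal. A quick standalone illustration:

```scala
val a = (Array(1, 2), 2L)
val b = (Array(1, 2), 2L)
assert(a != b)                                   // tuples compare element-wise, and Array equality is reference-based
assert((a._1.toSeq, a._2) == (b._1.toSeq, b._2)) // Seq equality is structural
```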
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index d34888af2d73b..e331c75989187 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
import PeriodicGraphCheckpointerSuite._
- // TODO: Do I need to call count() on the graphs' RDDs?
-
test("Persisting") {
var graphsToCheck = Seq.empty[GraphToCheck]
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
+ val checkpointer =
+ new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
checkPersistence(graphsToCheck, 1)
var iteration = 2
while (iteration < 9) {
val graph = createGraph(sc)
- checkpointer.updateGraph(graph)
+ checkpointer.update(graph)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
checkPersistence(graphsToCheck, iteration)
iteration += 1
@@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
var graphsToCheck = Seq.empty[GraphToCheck]
sc.setCheckpointDir(path)
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
+ val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+ checkpointInterval, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
graph1.edges.count()
graph1.vertices.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
@@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
var iteration = 2
while (iteration < 9) {
val graph = createGraph(sc)
- checkpointer.updateGraph(graph)
+ checkpointer.update(graph)
graph.vertices.count()
graph.edges.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
@@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite {
} else {
// Graph should never be checkpointed
assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
- assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+ assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
}
} catch {
case e: AssertionError =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
new file mode 100644
index 0000000000000..b2a459a68b5fa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import PeriodicRDDCheckpointerSuite._
+
+ test("Persisting") {
+ var rddsToCheck = Seq.empty[RDDToCheck]
+
+ val rdd1 = createRDD(sc)
+ val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
+ checkpointer.update(rdd1)
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+ checkPersistence(rddsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val rdd = createRDD(sc)
+ checkpointer.update(rdd)
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+ checkPersistence(rddsToCheck, iteration)
+ iteration += 1
+ }
+ }
+
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var rddsToCheck = Seq.empty[RDDToCheck]
+ sc.setCheckpointDir(path)
+ val rdd1 = createRDD(sc)
+ val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
+ checkpointer.update(rdd1)
+ rdd1.count()
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+ checkCheckpoint(rddsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val rdd = createRDD(sc)
+ checkpointer.update(rdd)
+ rdd.count()
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+ checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
+
+ checkpointer.deleteAllCheckpoints()
+ rddsToCheck.foreach { rdd =>
+ confirmCheckpointRemoved(rdd.rdd)
+ }
+
+ Utils.deleteRecursively(tempDir)
+ }
+}
+
+private object PeriodicRDDCheckpointerSuite {
+
+ case class RDDToCheck(rdd: RDD[Double], gIndex: Int)
+
+ def createRDD(sc: SparkContext): RDD[Double] = {
+ sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
+ }
+
+ def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
+ rdds.foreach { g =>
+ checkPersistence(g.rdd, g.gIndex, iteration)
+ }
+ }
+
+ /**
+ * Check storage level of rdd.
+ * @param gIndex Index of rdd in order inserted into checkpointer (from 1).
+ * @param iteration Total number of rdds inserted into checkpointer.
+ */
+ def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
+ try {
+ if (gIndex + 2 < iteration) {
+ assert(rdd.getStorageLevel == StorageLevel.NONE)
+ } else {
+ assert(rdd.getStorageLevel != StorageLevel.NONE)
+ }
+ } catch {
+ case _: AssertionError =>
+ throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
+ }
+ }
+
+ def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+ rdds.reverse.foreach { g =>
+ checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
+ }
+ }
+
+ def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
+ // Note: We cannot check rdd.isCheckpointed since that value is never updated.
+ // Instead, we check for the presence of the checkpoint files.
+ // This test should continue to work even after this rdd.isCheckpointed issue
+ // is fixed (though it can then be simplified and not look for the files).
+ val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration)
+ rdd.getCheckpointFile.foreach { checkpointFile =>
+ assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed")
+ }
+ }
+
+ /**
+ * Check checkpointed status of rdd.
+ * @param gIndex Index of rdd in order inserted into checkpointer (from 1).
+ * @param iteration Total number of rdds inserted into checkpointer.
+ */
+ def checkCheckpoint(
+ rdd: RDD[_],
+ gIndex: Int,
+ iteration: Int,
+ checkpointInterval: Int): Unit = {
+ try {
+ if (gIndex % checkpointInterval == 0) {
+ // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
+ // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+ assert(rdd.isCheckpointed, "RDD should be checkpointed")
+ assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
+ } else {
+ confirmCheckpointRemoved(rdd)
+ }
+ } else {
+ // RDD should never be checkpointed
+ assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
+ assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
+ }
+ } catch {
+ case e: AssertionError =>
+ throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t checkpointInterval = $checkpointInterval\n" +
+ s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
+ s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
+ s" AssertionError message: ${e.getMessage}")
+ }
+ }
+
+}
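
The assertions in this new suite reduce to two bookkeeping rules over the insertion index and the current iteration. Restated as standalone predicates (the names are illustrative, not Spark's):

```scala
// Only the two most recently updated RDDs are expected to remain persisted.
def expectPersisted(gIndex: Int, iteration: Int): Boolean =
  gIndex + 2 >= iteration

// Only RDDs inserted on a multiple of the checkpoint interval are ever checkpointed,
// and their checkpoint files are kept for at most two intervals behind the latest update.
def expectCheckpointed(gIndex: Int, iteration: Int, checkpointInterval: Int): Boolean =
  gIndex % checkpointInterval == 0 &&
    iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration
```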
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 03be4119bdaca..1c37ea5123e82 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -57,6 +57,21 @@ class VectorsSuite extends SparkFunSuite with Logging {
assert(vec.values === values)
}
+ test("sparse vector construction with mismatched indices/values array") {
+ intercept[IllegalArgumentException] {
+ Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0))
+ }
+ intercept[IllegalArgumentException] {
+ Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0))
+ }
+ }
+
+ test("sparse vector construction with too many indices vs size") {
+ intercept[IllegalArgumentException] {
+ Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0))
+ }
+ }
+
test("dense to array") {
val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
assert(vec.toArray.eq(arr))
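
The two new negative tests assume Vectors.sparse rejects mismatched or out-of-range arguments roughly as sketched below (illustrative only, not the actual validation code):

```scala
def validateSparseArgs(size: Int, indices: Array[Int], values: Array[Double]): Unit = {
  require(indices.length == values.length,
    s"indices and values must have the same length, " +
      s"got ${indices.length} and ${values.length}")
  require(indices.forall(i => i >= 0 && i < size),
    s"all indices must lie in [0, $size)")
}
```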
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index b6cb53d0c743e..283ffec1d49d7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
import scala.util.Random
+import breeze.numerics.abs
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
import org.apache.spark.SparkFunSuite
@@ -238,6 +239,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("QR Decomposition") {
+ for (mat <- Seq(denseMat, sparseMat)) {
+ val result = mat.tallSkinnyQR(true)
+ val expected = breeze.linalg.qr.reduced(mat.toBreeze())
+ val calcQ = result.Q
+ val calcR = result.R
+ assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze())))
+ assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]])))
+ assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze()))
+ // Decomposition without computing Q
+ val rOnly = mat.tallSkinnyQR(computeQ = false)
+ assert(rOnly.Q == null)
+ assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]])))
+ }
+ }
}
class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
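
The QR test compares against Breeze's reduced factorization for its expected values. A self-contained sanity check of the same property on a tiny matrix (illustrative only):

```scala
import breeze.linalg.{max, qr, DenseMatrix => BDM}
import breeze.numerics.abs

object QrCheckSketch {
  def main(args: Array[String]): Unit = {
    val a = BDM((4.0, 0.0), (3.0, 1.0), (0.0, 2.0))
    val factored = qr.reduced(a)
    // Q has orthonormal columns and Q * R reconstructs the original matrix.
    assert(max(abs(factored.q * factored.r - a)) < 1e-10)
    println("QR reconstruction check passed")
  }
}
```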
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index a2a4c5f6b8b70..34c07ed170816 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
-import org.apache.spark.streaming.TestSuiteBase
class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 20000
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
// Assert that two values are equal within tolerance epsilon
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
def errorMessage = v1.toString + " did not equal " + v2.toString
@@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
}
// apply model training to input stream
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
@@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
// apply model training to input stream, storing the intermediate results
// (we add a count to ensure the result is a DStream)
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
inputDStream.count()
@@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
}
// apply model predictions to test stream
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
// collect the output as (true, estimated) tuples
@@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
}
// train and predict
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
@@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
val numBatches = 10
val nPoints = 100
val emptyInput = Seq.empty[Seq[LabeledPoint]]
- val ssc = setupStreams(emptyInput,
+ ssc = setupStreams(emptyInput,
(inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 2521b3342181a..6fc9e8df621df 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss)
- (algos zip losses) map {
- case (algo, loss) => {
- val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy)
- .runWithValidation(trainRdd, validateRdd)
- val numTrees = gbtValidate.numTrees
- assert(numTrees !== numIterations)
-
- // Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
- val (errorWithoutValidation, errorWithValidation) = {
- if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
- } else {
- (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
- }
- }
- assert(errorWithValidation <= errorWithoutValidation)
-
- // Test that results from evaluateEachIteration comply with runWithValidation.
- // Note that convergenceTol is set to 0.0
- val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
- assert(evaluationArray.length === numIterations)
- assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
- var i = 1
- while (i < numTrees) {
- assert(evaluationArray(i) <= evaluationArray(i - 1))
- i += 1
+ algos.zip(losses).foreach { case (algo, loss) =>
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+ .runWithValidation(trainRdd, validateRdd)
+ val numTrees = gbtValidate.numTrees
+ assert(numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+ } else {
+ (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
}
}
+ assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that convergenceTol is set to 0.0
+ val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
}
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
+ val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)
+
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
}
private object GradientBoostedTreesSuite {
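
The runWithValidation assertions above presume an early-stopping rule of this general shape: keep adding trees while the validation error improves by more than validationTol, so with a tolerance of 0.0 the recorded errors decrease up to the kept tree count and then tick upward. The sketch below states that assumed rule explicitly; it is not Spark's code.

```scala
// Given per-iteration validation errors, return how many trees an early-stopping
// rule with the given tolerance would keep (assumed behavior, for illustration).
def treesKeptWithValidation(validationErrors: Array[Double], validationTol: Double): Int = {
  var best = validationErrors.head
  var kept = 1
  while (kept < validationErrors.length && best - validationErrors(kept) > validationTol) {
    best = validationErrors(kept)
    kept += 1
  }
  kept
}

// Example: errors improve for three iterations and then worsen; with tol = 0.0 the rule
// keeps 3 trees, mirroring assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)).
// treesKeptWithValidation(Array(0.9, 0.7, 0.6, 0.65, 0.64), 0.0) == 3
```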
diff --git a/pylintrc b/pylintrc
index 061775960393b..6a675770da69a 100644
--- a/pylintrc
+++ b/pylintrc
@@ -84,7 +84,7 @@ enable=
# If you would like to improve the code quality of pyspark, remove any of these disabled errors
# run ./dev/lint-python and see if the errors raised by pylint can be fixed.
-disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable
+disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable
[REPORTS]
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 9ef93071d2e77..3b647985801b7 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack):
if new_override:
d['__new__'] = obj.__new__
- self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
+ self.save(_load_class)
+ self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj)
+ d.pop('__doc__', None)
+ # handle property and staticmethod
+ dd = {}
+ for k, v in d.items():
+ if isinstance(v, property):
+ k = ('property', k)
+ v = (v.fget, v.fset, v.fdel, v.__doc__)
+ elif isinstance(v, staticmethod) and hasattr(v, '__func__'):
+ k = ('staticmethod', k)
+ v = v.__func__
+ elif isinstance(v, classmethod) and hasattr(v, '__func__'):
+ k = ('classmethod', k)
+ v = v.__func__
+ dd[k] = v
+ self.save(dd)
+ self.write(pickle.TUPLE2)
+ self.write(pickle.REDUCE)
+
else:
raise pickle.PicklingError("Can't pickle %r" % obj)
@@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None):
None, None, closure)
+def _load_class(cls, d):
+ """
+ Loads additional properties into class `cls`.
+ """
+ for k, v in d.items():
+ if isinstance(k, tuple):
+ typ, k = k
+ if typ == 'property':
+ v = property(*v)
+ elif typ == 'staticmethod':
+ v = staticmethod(v)
+ elif typ == 'classmethod':
+ v = classmethod(v)
+ setattr(cls, k, v)
+ return cls
+
+
"""Constructors for 3rd party libraries
Note: These can never be renamed due to client compatibility issues"""
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 90cd342a6cf7f..60be85e53e2aa 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -52,7 +52,11 @@ def launch_gateway():
script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
if os.environ.get("SPARK_TESTING"):
- submit_args = "--conf spark.ui.enabled=false " + submit_args
+ submit_args = ' '.join([
+ "--conf spark.ui.enabled=false",
+ "--conf spark.buffer.pageSize=4mb",
+ submit_args
+ ])
command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
# Start a socket that will be used by PythonGatewayServer to communicate its port to us
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 89117e492846b..5a82bc286d1e8 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
- >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
+ >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
- >>> allclose(model.treeWeights, [1.0, 1.0])
+ >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 86e654dd0779f..015e7a9d4900a 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -525,7 +525,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
A regex based tokenizer that extracts tokens either by using the
provided regex pattern (in Java dialect) to split the text
- (default) or repeatedly matching the regex (if gaps is true).
+ (default) or repeatedly matching the regex (if gaps is false).
Optional parameters also allow filtering tokens using a minimal
length.
It returns an array of strings that can be empty.
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 58ad99d46e23b..900ade248c386 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -152,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"
return KMeansModel([c.toArray() for c in centers])
-class GaussianMixtureModel(object):
+@inherit_doc
+class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
+
+ """
+ .. note:: Experimental
- """A clustering model derived from the Gaussian Mixture Model method.
+ A clustering model derived from the Gaussian Mixture Model method.
>>> from pyspark.mllib.linalg import Vectors, DenseMatrix
+ >>> from numpy.testing import assert_equal
+ >>> from shutil import rmtree
+ >>> import os, tempfile
+
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
... 0.9,0.8,0.75,0.935,
... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
@@ -169,6 +177,25 @@ class GaussianMixtureModel(object):
True
>>> labels[4]==labels[5]
True
+
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = GaussianMixtureModel.load(sc, path)
+ >>> assert_equal(model.weights, sameModel.weights)
+ >>> mus, sigmas = list(
+ ... zip(*[(g.mu, g.sigma) for g in model.gaussians]))
+ >>> sameMus, sameSigmas = list(
+ ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
+ >>> mus == sameMus
+ True
+ >>> sigmas == sameSigmas
+ True
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
+
>>> data = array([-5.1971, -2.5359, -3.8220,
... -5.2211, -5.0602, 4.7118,
... 6.8989, 3.4592, 4.6322,
@@ -182,25 +209,15 @@ class GaussianMixtureModel(object):
True
>>> labels[3]==labels[4]
True
- >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
- >>> im = GaussianMixtureModel([0.5, 0.5],
- ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])),
- ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))])
- >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
"""
- def __init__(self, weights, gaussians):
- self._weights = weights
- self._gaussians = gaussians
- self._k = len(self._weights)
-
@property
def weights(self):
"""
Weights for each Gaussian distribution in the mixture, where weights[i] is
the weight for Gaussian i, and weights.sum == 1.
"""
- return self._weights
+ return array(self.call("weights"))
@property
def gaussians(self):
@@ -208,12 +225,14 @@ def gaussians(self):
Array of MultivariateGaussian where gaussians[i] represents
the Multivariate Gaussian (Normal) Distribution for Gaussian i.
"""
- return self._gaussians
+ return [
+ MultivariateGaussian(gaussian[0], gaussian[1])
+ for gaussian in zip(*self.call("gaussians"))]
@property
def k(self):
"""Number of gaussians in mixture."""
- return self._k
+ return len(self.weights)
def predict(self, x):
"""
@@ -238,17 +257,30 @@ def predictSoft(self, x):
:return: membership_matrix. RDD of array of double values.
"""
if isinstance(x, RDD):
- means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians])
+ means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
- _convert_to_vector(self._weights), means, sigmas)
+ _convert_to_vector(self.weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))
+ @classmethod
+ def load(cls, sc, path):
+ """Load the GaussianMixtureModel from disk.
+
+ :param sc: SparkContext
+ :param path: str, path to where the model is stored.
+ """
+ model = cls._load_java(sc, path)
+ wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
+ return cls(wrapper)
+
class GaussianMixture(object):
"""
+ .. note:: Experimental
+
Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm.
:param data: RDD of data points
@@ -271,11 +303,10 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
initialModelWeights = initialModel.weights
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
- weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
- k, convergenceTol, maxIterations, seed,
- initialModelWeights, initialModelMu, initialModelSigma)
- mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
- return GaussianMixtureModel(weight, mvg_obj)
+ java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
+ k, convergenceTol, maxIterations, seed,
+ initialModelWeights, initialModelMu, initialModelSigma)
+ return GaussianMixtureModel(java_model)
class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py
similarity index 100%
rename from python/pyspark/mllib/linalg.py
rename to python/pyspark/mllib/linalg/__init__.py
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 875d3b2d642c6..916de2d6fcdbd 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -21,7 +21,9 @@
if sys.version > '3':
xrange = range
+ basestring = str
+from pyspark import SparkContext
from pyspark.mllib.common import callMLlibFunc, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
@@ -223,6 +225,10 @@ class JavaSaveable(Saveable):
"""
def save(self, sc, path):
+ if not isinstance(sc, SparkContext):
+ raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
self._java_model.save(sc._jsc.sc(), path)
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 8fb71bac64a5e..b8118bdb7ca76 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -606,7 +606,7 @@ def _open_file(self):
if not os.path.exists(d):
os.makedirs(d)
p = os.path.join(d, str(id(self)))
- self._file = open(p, "wb+", 65536)
+ self._file = open(p, "w+b", 65536)
self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
os.unlink(p)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index abb6522dde7b0..917de24f3536b 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -277,6 +277,66 @@ def applySchema(self, rdd, schema):
return self.createDataFrame(rdd, schema)
+ def _createFromRDD(self, rdd, schema, samplingRatio):
+ """
+ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
+ """
+ if schema is None or isinstance(schema, (list, tuple)):
+ struct = self._inferSchema(rdd, samplingRatio)
+ converter = _create_converter(struct)
+ rdd = rdd.map(converter)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ struct.names[i] = name
+ schema = struct
+
+ elif isinstance(schema, StructType):
+ # take the first few rows to verify schema
+ rows = rdd.take(10)
+ for row in rows:
+ _verify_type(row, schema)
+
+ else:
+ raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+ # convert python objects to sql data
+ rdd = rdd.map(schema.toInternal)
+ return rdd, schema
+
+ def _createFromLocal(self, data, schema):
+ """
+ Create an RDD for DataFrame from a list or pandas.DataFrame, returns
+ the RDD and schema.
+ """
+ if has_pandas and isinstance(data, pandas.DataFrame):
+ if schema is None:
+ schema = [str(x) for x in data.columns]
+ data = [r.tolist() for r in data.to_records(index=False)]
+
+ # make sure data can be consumed multiple times
+ if not isinstance(data, list):
+ data = list(data)
+
+ if schema is None or isinstance(schema, (list, tuple)):
+ struct = self._inferSchemaFromList(data)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ struct.names[i] = name
+ schema = struct
+
+ elif isinstance(schema, StructType):
+ for row in data:
+ _verify_type(row, schema)
+
+ else:
+ raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+ # convert python objects to sql data
+ data = [schema.toInternal(row) for row in data]
+ return self._sc.parallelize(data), schema
+
@since(1.3)
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
@@ -340,49 +400,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")
- if has_pandas and isinstance(data, pandas.DataFrame):
- if schema is None:
- schema = [str(x) for x in data.columns]
- data = [r.tolist() for r in data.to_records(index=False)]
-
- if not isinstance(data, RDD):
- if not isinstance(data, list):
- data = list(data)
- try:
- # data could be list, tuple, generator ...
- rdd = self._sc.parallelize(data)
- except Exception:
- raise TypeError("cannot create an RDD from type: %s" % type(data))
+ if isinstance(data, RDD):
+ rdd, schema = self._createFromRDD(data, schema, samplingRatio)
else:
- rdd = data
-
- if schema is None or isinstance(schema, (list, tuple)):
- if isinstance(data, RDD):
- struct = self._inferSchema(rdd, samplingRatio)
- else:
- struct = self._inferSchemaFromList(data)
- if isinstance(schema, (list, tuple)):
- for i, name in enumerate(schema):
- struct.fields[i].name = name
- schema = struct
- converter = _create_converter(schema)
- rdd = rdd.map(converter)
-
- elif isinstance(schema, StructType):
- # take the first few rows to verify schema
- rows = rdd.take(10)
- for row in rows:
- _verify_type(row, schema)
-
- else:
- raise TypeError("schema should be StructType or list or None")
-
- # convert python objects to sql data
- rdd = rdd.map(schema.toInternal)
-
+ rdd, schema = self._createFromLocal(data, schema)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- return DataFrame(df, self)
+ jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ df = DataFrame(jdf, self)
+ df._schema = schema
+ return df
@since(1.3)
def registerDataFrameAsTable(self, df, tableName):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d76e051bd73a1..0f3480c239187 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -441,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
+ @since(1.5)
+ def sampleBy(self, col, fractions, seed=None):
+ """
+ Returns a stratified sample without replacement based on the
+ fraction given on each stratum.
+
+ :param col: column that defines strata
+ :param fractions:
+ sampling fraction for each stratum. If a stratum is not
+ specified, we treat its fraction as zero.
+ :param seed: random seed
+ :return: a new DataFrame that represents the stratified sample
+
+ >>> from pyspark.sql.functions import col
+ >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
+ >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
+ >>> sampled.groupBy("key").count().orderBy("key").show()
+ +---+-----+
+ |key|count|
+ +---+-----+
+ | 0| 3|
+ | 1| 8|
+ +---+-----+
+
+ """
+ if not isinstance(col, str):
+ raise ValueError("col must be a string, but got %r" % type(col))
+ if not isinstance(fractions, dict):
+ raise ValueError("fractions must be a dict but got %r" % type(fractions))
+ for k, v in fractions.items():
+ if not isinstance(k, (float, int, long, basestring)):
+ raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
+ fractions[k] = float(v)
+ seed = seed if seed is not None else random.randint(0, sys.maxsize)
+ return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+
@since(1.4)
def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1314,6 +1350,11 @@ def freqItems(self, cols, support=None):
freqItems.__doc__ = DataFrame.freqItems.__doc__
+ def sampleBy(self, col, fractions, seed=None):
+ return self.df.sampleBy(col, fractions, seed)
+
+ sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
def _test():
import doctest
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d930f7db25d25..8024a8de07c98 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -59,7 +59,7 @@
__all__ += ['lag', 'lead', 'ntile']
__all__ += [
- 'date_format',
+ 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
'year', 'quarter', 'month', 'hour', 'minute', 'second',
'dayofmonth', 'dayofyear', 'weekofyear']
@@ -716,7 +716,7 @@ def date_format(dateCol, format):
[Row(date=u'04/08/2015')]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.date_format(dateCol, format))
+ return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))
@since(1.5)
@@ -729,7 +729,7 @@ def year(col):
[Row(year=2015)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.year(col))
+ return Column(sc._jvm.functions.year(_to_java_column(col)))
@since(1.5)
@@ -742,7 +742,7 @@ def quarter(col):
[Row(quarter=2)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.quarter(col))
+ return Column(sc._jvm.functions.quarter(_to_java_column(col)))
@since(1.5)
@@ -755,7 +755,7 @@ def month(col):
[Row(month=4)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.month(col))
+ return Column(sc._jvm.functions.month(_to_java_column(col)))
@since(1.5)
@@ -768,7 +768,7 @@ def dayofmonth(col):
[Row(day=8)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.dayofmonth(col))
+ return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))
@since(1.5)
@@ -781,7 +781,7 @@ def dayofyear(col):
[Row(day=98)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.dayofyear(col))
+ return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))
@since(1.5)
@@ -794,7 +794,7 @@ def hour(col):
[Row(hour=13)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.hour(col))
+ return Column(sc._jvm.functions.hour(_to_java_column(col)))
@since(1.5)
@@ -807,7 +807,7 @@ def minute(col):
[Row(minute=8)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.minute(col))
+ return Column(sc._jvm.functions.minute(_to_java_column(col)))
@since(1.5)
@@ -820,7 +820,7 @@ def second(col):
[Row(second=15)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.second(col))
+ return Column(sc._jvm.functions.second(_to_java_column(col)))
@since(1.5)
@@ -829,11 +829,93 @@ def weekofyear(col):
Extract the week number of a given date as integer.
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(weekofyear('a').alias('week')).collect()
+ >>> df.select(weekofyear(df.a).alias('week')).collect()
[Row(week=15)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.weekofyear(col))
+ return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))
+
+
+@since(1.5)
+def date_add(start, days):
+ """
+ Returns the date that is `days` days after `start`
+
+ >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+ >>> df.select(date_add(df.d, 1).alias('d')).collect()
+ [Row(d=datetime.date(2015, 4, 9))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
+
+
+@since(1.5)
+def date_sub(start, days):
+ """
+ Returns the date that is `days` days before `start`
+
+ >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+ >>> df.select(date_sub(df.d, 1).alias('d')).collect()
+ [Row(d=datetime.date(2015, 4, 7))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
+
+
+@since(1.5)
+def add_months(start, months):
+ """
+ Returns the date that is `months` months after `start`
+
+ >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+ >>> df.select(add_months(df.d, 1).alias('d')).collect()
+ [Row(d=datetime.date(2015, 5, 8))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
+
+
+@since(1.5)
+def months_between(date1, date2):
+ """
+ Returns the number of months between date1 and date2.
+
+ >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
+ >>> df.select(months_between(df.t, df.d).alias('months')).collect()
+ [Row(months=3.9495967...)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))
+
+
+@since(1.5)
+def to_date(col):
+ """
+ Converts the column of StringType or TimestampType into DateType.
+
+ >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+ >>> df.select(to_date(df.t).alias('date')).collect()
+ [Row(date=datetime.date(1997, 2, 28))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.to_date(_to_java_column(col)))
+
+
+@since(1.5)
+def trunc(date, format):
+ """
+ Returns date truncated to the unit specified by the format.
+
+ :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
+
+ >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d'])
+ >>> df.select(trunc(df.d, 'year').alias('year')).collect()
+ [Row(year=datetime.date(1997, 1, 1))]
+ >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
+ [Row(month=datetime.date(1997, 2, 1))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
@since(1.5)
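
A combined usage sketch of the datetime helpers added above; it assumes the same sqlContext as in the doctests, and the column names are illustrative only.

    from pyspark.sql.functions import (date_add, add_months, months_between,
                                       to_date, trunc)

    df = sqlContext.createDataFrame([('2015-04-08', '2015-05-10')], ['start', 'end'])
    df.select(
        date_add(df.start, 7).alias('plus_week'),           # 2015-04-15
        add_months(df.start, 3).alias('plus_quarter'),       # 2015-07-08
        months_between(df.end, df.start).alias('months'),    # about 1.06
        trunc(df.start, 'mon').alias('month_start'),         # 2015-04-01
        to_date(df.end).alias('end_date')).show()
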
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5aa6135dc1ee7..ebd3ea8db6a43 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -75,7 +75,7 @@ def sqlType(self):
@classmethod
def module(cls):
- return 'pyspark.tests'
+ return 'pyspark.sql.tests'
@classmethod
def scalaUDT(cls):
@@ -106,10 +106,45 @@ def __str__(self):
return "(%s,%s)" % (self.x, self.y)
def __eq__(self, other):
- return isinstance(other, ExamplePoint) and \
+ return isinstance(other, self.__class__) and \
other.x == self.x and other.y == self.y
+class PythonOnlyUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for PythonOnlyPoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return '__main__'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return PythonOnlyPoint(datum[0], datum[1])
+
+ @staticmethod
+ def foo():
+ pass
+
+ @property
+ def props(self):
+ return {}
+
+
+class PythonOnlyPoint(ExamplePoint):
+ """
+ An example class to demonstrate a UDT implemented only in Python
+ """
+ __UDT__ = PythonOnlyUDT()
+
+
class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
def test_data_type_eq(self):
@@ -395,10 +430,39 @@ def test_convert_row_to_dict(self):
self.assertEqual(1, row.asDict()["l"][0].a)
self.assertEqual(1.0, row.asDict()['d']['key'].c)
+ def test_udt(self):
+ from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+ from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
+
+ def check_datatype(datatype):
+ pickled = pickle.loads(pickle.dumps(datatype))
+ assert datatype == pickled
+ scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ python_datatype = _parse_datatype_json_string(scala_datatype.json())
+ assert datatype == python_datatype
+
+ check_datatype(ExamplePointUDT())
+ structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ check_datatype(structtype_with_udt)
+ p = ExamplePoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), ExamplePointUDT())
+ _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+ self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+
+ check_datatype(PythonOnlyUDT())
+ structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ StructField("point", PythonOnlyUDT(), False)])
+ check_datatype(structtype_with_udt)
+ p = PythonOnlyPoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), PythonOnlyUDT())
+ _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
+ self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.sc.parallelize([row]).toDF()
+ df = self.sqlCtx.createDataFrame([row])
schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
@@ -406,36 +470,66 @@ def test_infer_schema_with_udt(self):
point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df = self.sqlCtx.createDataFrame([row])
+ schema = df.schema
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), PythonOnlyUDT)
+ df.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+ self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
def test_apply_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = (1.0, ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = rdd.toDF(schema)
+ df = self.sqlCtx.createDataFrame([row], schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ row = (1.0, PythonOnlyPoint(1.0, 2.0))
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", PythonOnlyUDT(), False)])
+ df = self.sqlCtx.createDataFrame([row], schema)
+ point = df.head().point
+ self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
def test_udf_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.sc.parallelize([row]).toDF()
+ df = self.sqlCtx.createDataFrame([row])
self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df = self.sqlCtx.createDataFrame([row])
+ self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
+ udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+ self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+ self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
def test_parquet_with_udt(self):
- from pyspark.sql.tests import ExamplePoint
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df0 = self.sc.parallelize([row]).toDF()
+ df0 = self.sqlCtx.createDataFrame([row])
output_dir = os.path.join(self.tempdir.name, "labeled_point")
- df0.saveAsParquetFile(output_dir)
+ df0.write.parquet(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
point = df1.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df0 = self.sqlCtx.createDataFrame([row])
+ df0.write.parquet(output_dir, mode='overwrite')
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ point = df1.head().point
+ self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
def test_column_operators(self):
ci = self.df.key
cs = self.df.value
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index b97d50c945f24..6f74b7162f7cc 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -22,6 +22,7 @@
import calendar
import json
import re
+import base64
from array import array
if sys.version >= "3":
@@ -31,6 +32,8 @@
from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass
+from pyspark.serializers import CloudPickleSerializer
+
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
@@ -458,7 +461,7 @@ def __init__(self, fields=None):
self.names = [f.name for f in fields]
assert all(isinstance(f, StructField) for f in fields),\
"fields should be a list of StructField"
- self._needSerializeFields = None
+ self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
def add(self, field, data_type=None, nullable=True, metadata=None):
"""
@@ -501,6 +504,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
data_type_f = data_type
self.fields.append(StructField(field, data_type_f, nullable, metadata))
self.names.append(field)
+ self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
return self
def simpleString(self):
@@ -526,12 +530,9 @@ def toInternal(self, obj):
if obj is None:
return
- if self._needSerializeFields is None:
- self._needSerializeFields = any(f.needConversion() for f in self.fields)
-
- if self._needSerializeFields:
+ if self._needSerializeAnyField:
if isinstance(obj, dict):
- return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields))
+ return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
elif isinstance(obj, (tuple, list)):
return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
else:
@@ -550,7 +551,10 @@ def fromInternal(self, obj):
if isinstance(obj, Row):
# it's already converted by pickler
return obj
- values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
+ if self._needSerializeAnyField:
+ values = [f.fromInternal(v) for f, v in zip(self.fields, obj)]
+ else:
+ values = obj
return _create_row(self.names, values)
@@ -581,9 +585,10 @@ def module(cls):
@classmethod
def scalaUDT(cls):
"""
- The class name of the paired Scala UDT.
+ The class name of the paired Scala UDT (could be '', if there
+ is no corresponding one).
"""
- raise NotImplementedError("UDT must have a paired Scala UDT.")
+ return ''
def needConversion(self):
return True
@@ -622,22 +627,37 @@ def json(self):
return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
def jsonValue(self):
- schema = {
- "type": "udt",
- "class": self.scalaUDT(),
- "pyClass": "%s.%s" % (self.module(), type(self).__name__),
- "sqlType": self.sqlType().jsonValue()
- }
+ if self.scalaUDT():
+ assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
+ schema = {
+ "type": "udt",
+ "class": self.scalaUDT(),
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "sqlType": self.sqlType().jsonValue()
+ }
+ else:
+ ser = CloudPickleSerializer()
+ b = ser.dumps(type(self))
+ schema = {
+ "type": "udt",
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "serializedClass": base64.b64encode(b).decode('utf8'),
+ "sqlType": self.sqlType().jsonValue()
+ }
return schema
@classmethod
def fromJson(cls, json):
- pyUDT = json["pyClass"]
+ pyUDT = str(json["pyClass"]) # convert unicode to str
split = pyUDT.rfind(".")
pyModule = pyUDT[:split]
pyClass = pyUDT[split+1:]
m = __import__(pyModule, globals(), locals(), [pyClass])
- UDT = getattr(m, pyClass)
+ if not hasattr(m, pyClass):
+ s = base64.b64decode(json['serializedClass'].encode('utf-8'))
+ UDT = CloudPickleSerializer().loads(s)
+ else:
+ UDT = getattr(m, pyClass)
return UDT()
def __eq__(self, other):
@@ -696,11 +716,6 @@ def _parse_datatype_json_string(json_string):
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
-
- >>> check_datatype(ExamplePointUDT())
- >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
- ... StructField("point", ExamplePointUDT(), False)])
- >>> check_datatype(structtype_with_udt)
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -752,10 +767,6 @@ def _parse_datatype_json_value(json_value):
def _infer_type(obj):
"""Infer the DataType from obj
-
- >>> p = ExamplePoint(1.0, 2.0)
- >>> _infer_type(p)
- ExamplePointUDT
"""
if obj is None:
return NullType()
@@ -1090,11 +1101,6 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
- >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
- >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
"""
# all objects are nullable
if obj is None:
@@ -1259,18 +1265,12 @@ def convert(self, obj, gateway_client):
def _test():
import doctest
from pyspark.context import SparkContext
- # let doctest run in pyspark.sql.types, so DataTypes can be picklable
- import pyspark.sql.types
- from pyspark.sql import Row, SQLContext
- from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
- globs = pyspark.sql.types.__dict__.copy()
+ from pyspark.sql import SQLContext
+ globs = globals()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
- globs['ExamplePoint'] = ExamplePoint
- globs['ExamplePointUDT'] = ExamplePointUDT
- (failure_count, test_count) = doctest.testmod(
- pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
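
The jsonValue/fromJson changes above let a UDT that has no Scala counterpart carry its own class inside the schema JSON. Below is a minimal sketch of that round trip using only the pieces the patch relies on (CloudPickleSerializer and base64); encode_udt_class and decode_udt_class are hypothetical helper names used for illustration.

    import base64
    from pyspark.serializers import CloudPickleSerializer

    _ser = CloudPickleSerializer()

    def encode_udt_class(udt_class):
        # What jsonValue() does when scalaUDT() is '': pickle the class itself
        # and embed it as base64 text under the "serializedClass" key.
        return base64.b64encode(_ser.dumps(udt_class)).decode('utf8')

    def decode_udt_class(serialized):
        # What fromJson() falls back to when the named module does not expose
        # the class (e.g. it was defined in __main__ on another process).
        return _ser.loads(base64.b64decode(serialized.encode('utf-8')))
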
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
new file mode 100644
index 0000000000000..e3d3ba7a9ccc0
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.ArrayData;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public interface SpecializedGetters {
+
+ boolean isNullAt(int ordinal);
+
+ boolean getBoolean(int ordinal);
+
+ byte getByte(int ordinal);
+
+ short getShort(int ordinal);
+
+ int getInt(int ordinal);
+
+ long getLong(int ordinal);
+
+ float getFloat(int ordinal);
+
+ double getDouble(int ordinal);
+
+ Decimal getDecimal(int ordinal, int precision, int scale);
+
+ UTF8String getUTF8String(int ordinal);
+
+ byte[] getBinary(int ordinal);
+
+ CalendarInterval getInterval(int ordinal);
+
+ InternalRow getStruct(int ordinal, int numFields);
+
+ ArrayData getArray(int ordinal);
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 684de6e81d67c..f3b462778dc10 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -20,6 +20,8 @@
import java.util.Iterator;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
@@ -61,26 +63,18 @@ public final class UnsafeFixedWidthAggregationMap {
private final boolean enablePerfMetrics;
- /**
- * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
- * false otherwise.
- */
- public static boolean supportsGroupKeySchema(StructType schema) {
- for (StructField field: schema.fields()) {
- if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
- return false;
- }
- }
- return true;
- }
-
/**
* @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
* schema, false otherwise.
*/
public static boolean supportsAggregationBufferSchema(StructType schema) {
for (StructField field: schema.fields()) {
- if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ if (field.dataType() instanceof DecimalType) {
+ DecimalType dt = (DecimalType) field.dataType();
+ if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
+ return false;
+ }
+ } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
return false;
}
}
@@ -95,6 +89,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
* @param memoryManager the memory manager used to allocate our Unsafe memory structures.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+ * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
*/
public UnsafeFixedWidthAggregationMap(
@@ -103,11 +98,13 @@ public UnsafeFixedWidthAggregationMap(
StructType groupingKeySchema,
TaskMemoryManager memoryManager,
int initialCapacity,
+ long pageSizeBytes,
boolean enablePerfMetrics) {
this.aggregationBufferSchema = aggregationBufferSchema;
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
- this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+ this.map =
+ new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
this.enablePerfMetrics = enablePerfMetrics;
// Initialize the buffer for aggregation value
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index fb084dd13b620..e7088edced1a1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -19,6 +19,8 @@
import java.io.IOException;
import java.io.OutputStream;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
@@ -29,7 +31,7 @@
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
-import org.apache.spark.unsafe.types.Interval;
+import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*;
@@ -65,12 +67,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
*/
public static final Set<DataType> settableFieldTypes;
- /**
- * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException).
- */
- public static final Set<DataType> readableFieldTypes;
-
- // TODO: support DecimalType
+ // DecimalType(precision <= 18) is settable
static {
settableFieldTypes = Collections.unmodifiableSet(
new HashSet<>(
@@ -86,16 +83,6 @@ public static int calculateBitSetWidthInBytes(int numFields) {
DateType,
TimestampType
})));
-
- // We support get() on a superset of the types for which we support set():
- final Set<DataType> _readableFieldTypes = new HashSet<>(
- Arrays.asList(new DataType[]{
- StringType,
- BinaryType,
- IntervalType
- }));
- _readableFieldTypes.addAll(settableFieldTypes);
- readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
}
//////////////////////////////////////////////////////////////////////////////
@@ -232,6 +219,21 @@ public void setFloat(int ordinal, float value) {
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}
+ @Override
+ public void setDecimal(int ordinal, Decimal value, int precision) {
+ assertIndexIsValid(ordinal);
+ if (value == null) {
+ setNullAt(ordinal);
+ } else {
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ setLong(ordinal, value.toUnscaledLong());
+ } else {
+ // TODO(davies): support updating decimals (hold a bounded space even if the value is null)
+ throw new UnsupportedOperationException();
+ }
+ }
+ }
+
@Override
public Object get(int ordinal) {
throw new UnsupportedOperationException();
@@ -239,7 +241,7 @@ public Object get(int ordinal) {
@Override
public Object get(int ordinal, DataType dataType) {
- if (dataType instanceof NullType) {
+ if (isNullAt(ordinal) || dataType instanceof NullType) {
return null;
} else if (dataType instanceof BooleanType) {
return getBoolean(ordinal);
@@ -256,7 +258,8 @@ public Object get(int ordinal, DataType dataType) {
} else if (dataType instanceof DoubleType) {
return getDouble(ordinal);
} else if (dataType instanceof DecimalType) {
- return getDecimal(ordinal);
+ DecimalType dt = (DecimalType) dataType;
+ return getDecimal(ordinal, dt.precision(), dt.scale());
} else if (dataType instanceof DateType) {
return getInt(ordinal);
} else if (dataType instanceof TimestampType) {
@@ -265,6 +268,8 @@ public Object get(int ordinal, DataType dataType) {
return getBinary(ordinal);
} else if (dataType instanceof StringType) {
return getUTF8String(ordinal);
+ } else if (dataType instanceof CalendarIntervalType) {
+ return getInterval(ordinal);
} else if (dataType instanceof StructType) {
return getStruct(ordinal, ((StructType) dataType).size());
} else {
@@ -311,20 +316,28 @@ public long getLong(int ordinal) {
@Override
public float getFloat(int ordinal) {
assertIndexIsValid(ordinal);
- if (isNullAt(ordinal)) {
- return Float.NaN;
- } else {
- return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal));
- }
+ return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal));
}
@Override
public double getDouble(int ordinal) {
+ assertIndexIsValid(ordinal);
+ return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
assertIndexIsValid(ordinal);
if (isNullAt(ordinal)) {
- return Float.NaN;
+ return null;
+ }
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(ordinal), precision, scale);
} else {
- return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
+ byte[] bytes = getBinary(ordinal);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
}
}
@@ -356,7 +369,7 @@ public byte[] getBinary(int ordinal) {
}
@Override
- public Interval getInterval(int ordinal) {
+ public CalendarInterval getInterval(int ordinal) {
if (isNullAt(ordinal)) {
return null;
} else {
@@ -365,7 +378,7 @@ public Interval getInterval(int ordinal) {
final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
final long microseconds =
PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8);
- return new Interval(months, microseconds);
+ return new CalendarInterval(months, microseconds);
}
}
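
A standalone sketch (plain Python, for illustration only) of the decimal decoding that getDecimal() now performs: precisions up to Decimal.MAX_LONG_DIGITS come back from one unscaled long, larger precisions from the big-endian two's-complement bytes stored in the variable-length region.

    from decimal import Decimal

    MAX_LONG_DIGITS = 18  # mirrors Decimal.MAX_LONG_DIGITS() on the JVM side

    def decode_decimal(unscaled_long, raw_bytes, precision, scale):
        if precision <= MAX_LONG_DIGITS:
            unscaled = unscaled_long                           # read via getLong(ordinal)
        else:
            # getBinary(ordinal) holds BigInteger.toByteArray() output:
            # big-endian, two's-complement, minimal length.
            unscaled = int.from_bytes(raw_bytes, 'big', signed=True)
        return Decimal(unscaled).scaleb(-scale)

    print(decode_decimal(12345, b'', 5, 2))   # Decimal('123.45')
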
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 0ba31d3b9b743..f43a285cd6cad 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray;
-import org.apache.spark.unsafe.types.Interval;
+import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -29,6 +31,47 @@
*/
public class UnsafeRowWriters {
+ /** Writer for Decimal with precision at most 18. */
+ public static class CompactDecimalWriter {
+
+ public static int getSize(Decimal input) {
+ return 0;
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+ target.setLong(ordinal, input.toUnscaledLong());
+ return 0;
+ }
+ }
+
+ /** Writer for Decimal with precision larger than 18. */
+ public static class DecimalWriter {
+
+ public static int getSize(Decimal input) {
+ // bounded size
+ return 16;
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+ final long offset = target.getBaseOffset() + cursor;
+ final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+ final int numBytes = bytes.length;
+ assert(numBytes <= 16);
+
+ // zero-out the bytes
+ PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L);
+ PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L);
+
+ // Write the bytes to the variable length portion.
+ PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET,
+ target.getBaseObject(), offset, numBytes);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+ return 16;
+ }
+ }
+
/** Writer for UTF8String. */
public static class UTF8StringWriter {
@@ -46,7 +89,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in
target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
}
- // Write the string to the variable length portion.
+ // Write the bytes to the variable length portion.
input.writeToMemory(target.getBaseObject(), offset);
// Set the fixed length portion.
@@ -72,7 +115,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input)
target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
}
- // Write the string to the variable length portion.
+ // Write the bytes to the variable length portion.
ByteArray.writeToMemory(input, target.getBaseObject(), offset);
// Set the fixed length portion.
@@ -81,10 +124,56 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input)
}
}
+ /**
+ * Writer for struct type where the struct field is backed by an {@link UnsafeRow}.
+ *
+ * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}.
+ * Non-UnsafeRow struct fields are handled directly in
+ * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}
+ * by generating the Java code needed to convert them into UnsafeRow.
+ */
+ public static class StructWriter {
+ public static int getSize(InternalRow input) {
+ int numBytes = 0;
+ if (input instanceof UnsafeRow) {
+ numBytes = ((UnsafeRow) input).getSizeInBytes();
+ } else {
+ // This is handled directly in GenerateUnsafeProjection.
+ throw new UnsupportedOperationException();
+ }
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) {
+ int numBytes = 0;
+ final long offset = target.getBaseOffset() + cursor;
+ if (input instanceof UnsafeRow) {
+ final UnsafeRow row = (UnsafeRow) input;
+ numBytes = row.getSizeInBytes();
+
+ // zero-out the padding bytes
+ if ((numBytes & 0x07) > 0) {
+ PlatformDependent.UNSAFE.putLong(
+ target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+ }
+
+ // Write the bytes to the variable length portion.
+ row.writeToMemory(target.getBaseObject(), offset);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+ } else {
+ // This is handled directly in GenerateUnsafeProjection.
+ throw new UnsupportedOperationException();
+ }
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ }
+ }
+
/** Writer for interval type. */
public static class IntervalWriter {
- public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) {
+ public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInterval input) {
final long offset = target.getBaseOffset() + cursor;
// Write the months and microseconds fields of Interval to the variable length portion.
@@ -96,5 +185,4 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Interval inpu
return 16;
}
}
-
}
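
To make the DecimalWriter layout above concrete, here is a plain-Python sketch of how a high-precision decimal is laid out: the unscaled value's two's-complement bytes land at the cursor in a zeroed 16-byte region, and the fixed-length slot stores the offset and length packed into one long. write_big_decimal is a hypothetical name, and the sketch assumes a non-negative unscaled value for brevity.

    def write_big_decimal(unscaled_value, cursor):
        # Minimal big-endian two's-complement encoding of a non-negative value,
        # the same bytes BigInteger.toByteArray() would produce.
        raw = unscaled_value.to_bytes((unscaled_value.bit_length() + 8) // 8,
                                      'big', signed=True)
        assert len(raw) <= 16, "the writer bounds decimals at 16 bytes"
        region = raw + b'\x00' * (16 - len(raw))   # zeroed region, bytes copied in
        pointer = (cursor << 32) | len(raw)        # value stored in the fixed slot
        return region, pointer

    region, pointer = write_big_decimal(12345678901234567890123, cursor=64)
    print(len(region), hex(pointer))               # 16 0x400000000a
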
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 4c3f2c6557140..68c49feae938e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter {
private long numRowsInserted = 0;
private final StructType schema;
- private final UnsafeProjection unsafeProjection;
private final PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;
@@ -62,7 +61,6 @@ public UnsafeExternalRowSorter(
PrefixComparator prefixComparator,
PrefixComputer prefixComputer) throws IOException {
this.schema = schema;
- this.unsafeProjection = UnsafeProjection.create(schema);
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
@@ -88,13 +86,12 @@ void setTestSpillFrequency(int frequency) {
}
@VisibleForTesting
- void insertRow(InternalRow row) throws IOException {
- UnsafeRow unsafeRow = unsafeProjection.apply(row);
+ void insertRow(UnsafeRow row) throws IOException {
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
- unsafeRow.getBaseObject(),
- unsafeRow.getBaseOffset(),
- unsafeRow.getSizeInBytes(),
+ row.getBaseObject(),
+ row.getBaseOffset(),
+ row.getSizeInBytes(),
prefix
);
numRowsInserted++;
@@ -113,7 +110,7 @@ private void cleanupResources() {
}
@VisibleForTesting
- Iterator<InternalRow> sort() throws IOException {
+ Iterator<UnsafeRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -121,7 +118,7 @@ Iterator sort() throws IOException {
// here in order to prevent memory leaks.
cleanupResources();
}
- return new AbstractScalaRowIterator() {
+ return new AbstractScalaRowIterator<UnsafeRow>() {
private final int numFields = schema.length();
private UnsafeRow row = new UnsafeRow();
@@ -132,7 +129,7 @@ public boolean hasNext() {
}
@Override
- public InternalRow next() {
+ public UnsafeRow next() {
try {
sortedIterator.loadNext();
row.pointTo(
@@ -164,11 +161,11 @@ public InternalRow next() {
}
- public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
- while (inputIterator.hasNext()) {
- insertRow(inputIterator.next());
- }
- return sort();
+ public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
+ while (inputIterator.hasNext()) {
+ insertRow(inputIterator.next());
+ }
+ return sort();
}
/**
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
index 5703de42393de..17659d7d960b0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
@@ -50,9 +50,9 @@ public class DataTypes {
public static final DataType TimestampType = TimestampType$.MODULE$;
/**
- * Gets the IntervalType object.
+ * Gets the CalendarIntervalType object.
*/
- public static final DataType IntervalType = IntervalType$.MODULE$;
+ public static final DataType CalendarIntervalType = CalendarIntervalType$.MODULE$;
/**
* Gets the DoubleType object.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d1d89a1f48329..7ca20fe97fbef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -55,7 +55,6 @@ object CatalystTypeConverters {
private def isWholePrimitive(dt: DataType): Boolean = dt match {
case dt if isPrimitive(dt) => true
- case ArrayType(elementType, _) => isWholePrimitive(elementType)
case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
case _ => false
}
@@ -69,7 +68,7 @@ object CatalystTypeConverters {
case StringType => StringConverter
case DateType => DateConverter
case TimestampType => TimestampConverter
- case dt: DecimalType => BigDecimalConverter
+ case dt: DecimalType => new DecimalConverter(dt)
case BooleanType => BooleanConverter
case ByteType => ByteConverter
case ShortType => ShortConverter
@@ -154,39 +153,41 @@ object CatalystTypeConverters {
/** Converter for arrays, sequences, and Java iterables. */
private case class ArrayConverter(
- elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
+ elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {
private[this] val elementConverter = getConverterForType(elementType)
private[this] val isNoChange = isWholePrimitive(elementType)
- override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
+ override def toCatalystImpl(scalaValue: Any): ArrayData = {
scalaValue match {
- case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
- case s: Seq[_] => s.map(elementConverter.toCatalyst)
+ case a: Array[_] =>
+ new GenericArrayData(a.map(elementConverter.toCatalyst))
+ case s: Seq[_] =>
+ new GenericArrayData(s.map(elementConverter.toCatalyst).toArray)
case i: JavaIterable[_] =>
val iter = i.iterator
- var convertedIterable: List[Any] = List()
+ val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any]
while (iter.hasNext) {
val item = iter.next()
- convertedIterable :+= elementConverter.toCatalyst(item)
+ convertedIterable += elementConverter.toCatalyst(item)
}
- convertedIterable
+ new GenericArrayData(convertedIterable.toArray)
}
}
- override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
+ override def toScala(catalystValue: ArrayData): Seq[Any] = {
if (catalystValue == null) {
null
} else if (isNoChange) {
- catalystValue
+ catalystValue.toArray()
} else {
- catalystValue.map(elementConverter.toScala)
+ catalystValue.toArray().map(elementConverter.toScala)
}
}
override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
- toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]])
+ toScala(row.getArray(column))
}
private case class MapConverter(
@@ -305,7 +306,8 @@ object CatalystTypeConverters {
DateTimeUtils.toJavaTimestamp(row.getLong(column))
}
- private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+ private class DecimalConverter(dataType: DecimalType)
+ extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
case d: BigDecimal => Decimal(d)
case d: JavaBigDecimal => Decimal(d)
@@ -313,9 +315,11 @@ object CatalystTypeConverters {
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
- row.getDecimal(column).toJavaBigDecimal
+ row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal
}
+ private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT)
+
private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
final override def toScala(catalystValue: Any): Any = catalystValue
final override def toCatalystImpl(scalaValue: T): Any = scalaValue
@@ -402,9 +406,9 @@ object CatalystTypeConverters {
case t: Timestamp => TimestampConverter.toCatalyst(t)
case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
- case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
- case arr: Array[Any] => arr.map(convertToCatalyst)
+ case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
case m: Map[_, _] =>
m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
case other => other
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 9a11de3840ce2..b19bf4386b0ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{Interval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
* internal types.
*/
-abstract class InternalRow extends Serializable {
+abstract class InternalRow extends Serializable with SpecializedGetters {
def numFields: Int
@@ -38,29 +38,31 @@ abstract class InternalRow extends Serializable {
def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T]
- def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
+ override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
- def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType)
+ override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType)
- def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType)
+ override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType)
- def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType)
+ override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType)
- def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType)
+ override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType)
- def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType)
+ override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType)
- def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType)
+ override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType)
- def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType)
+ override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType)
- def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType)
+ override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType)
- def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType)
+ override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType)
- def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
+ getAs[Decimal](ordinal, DecimalType(precision, scale))
- def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType)
+ override def getInterval(ordinal: Int): CalendarInterval =
+ getAs[CalendarInterval](ordinal, CalendarIntervalType)
// This is only use for test and will throw a null pointer exception if the position is null.
def getString(ordinal: Int): String = getUTF8String(ordinal).toString
@@ -71,7 +73,10 @@ abstract class InternalRow extends Serializable {
* @param ordinal position to get the struct from.
* @param numFields number of fields the struct type has
*/
- def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null)
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow =
+ getAs[InternalRow](ordinal, null)
+
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null)
override def toString: String = s"[${this.mkString(",")}]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b423f0fa04f69..f2498861c9573 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.Interval
+import org.apache.spark.unsafe.types.CalendarInterval
/**
* A very simple SQL parser. Based loosely on:
@@ -332,8 +332,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val numericLiteral: Parser[Literal] =
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
| sign.? ~ unsignedFloat ^^ {
- // TODO(davies): some precisions may loss, we should create decimal literal
- case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
+ case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f))
}
)
@@ -366,32 +365,32 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val millisecond: Parser[Long] =
integral <~ intervalUnit("millisecond") ^^ {
- case num => num.toLong * Interval.MICROS_PER_MILLI
+ case num => num.toLong * CalendarInterval.MICROS_PER_MILLI
}
protected lazy val second: Parser[Long] =
integral <~ intervalUnit("second") ^^ {
- case num => num.toLong * Interval.MICROS_PER_SECOND
+ case num => num.toLong * CalendarInterval.MICROS_PER_SECOND
}
protected lazy val minute: Parser[Long] =
integral <~ intervalUnit("minute") ^^ {
- case num => num.toLong * Interval.MICROS_PER_MINUTE
+ case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE
}
protected lazy val hour: Parser[Long] =
integral <~ intervalUnit("hour") ^^ {
- case num => num.toLong * Interval.MICROS_PER_HOUR
+ case num => num.toLong * CalendarInterval.MICROS_PER_HOUR
}
protected lazy val day: Parser[Long] =
integral <~ intervalUnit("day") ^^ {
- case num => num.toLong * Interval.MICROS_PER_DAY
+ case num => num.toLong * CalendarInterval.MICROS_PER_DAY
}
protected lazy val week: Parser[Long] =
integral <~ intervalUnit("week") ^^ {
- case num => num.toLong * Interval.MICROS_PER_WEEK
+ case num => num.toLong * CalendarInterval.MICROS_PER_WEEK
}
protected lazy val intervalLiteral: Parser[Literal] =
@@ -407,7 +406,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
val months = Seq(year, month).map(_.getOrElse(0)).sum
val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond)
.map(_.getOrElse(0L)).sum
- Literal.create(new Interval(months, microseconds), IntervalType)
+ Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType)
}
private def toNarrowestIntegerType(value: String): Any = {
@@ -420,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
}
}
+ private def toDecimalOrDouble(value: String): Any = {
+ val decimal = BigDecimal(value)
+ // follow the behavior in MS SQL Server
+ // https://msdn.microsoft.com/en-us/library/ms179899.aspx
+ if (value.contains('E') || value.contains('e')) {
+ decimal.doubleValue()
+ } else {
+ decimal.underlying()
+ }
+ }
+
protected lazy val baseExpression: Parser[Expression] =
( "*" ^^^ UnresolvedStar(None)
| ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) }
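
The toDecimalOrDouble helper above follows the SQL Server convention linked in its comment: a numeric literal written with an exponent becomes a double, anything else stays an exact decimal. The same decision in a few lines of plain Python, for illustration only:

    from decimal import Decimal

    def to_decimal_or_double(text):
        # Exponent notation -> approximate double; plain notation -> exact decimal.
        if 'e' in text or 'E' in text:
            return float(text)
        return Decimal(text)

    print(to_decimal_or_double('1.5e3'))    # 1500.0 (float)
    print(to_decimal_or_double('1.230'))    # 1.230  (Decimal, keeps the scale)
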
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a723e92114b32..265f3d1e41765 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.expressions._
@@ -25,7 +27,6 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.types._
-import scala.collection.mutable.ArrayBuffer
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -78,6 +79,7 @@ class Analyzer(
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
+ RemoveEvaluationFromSort ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
@@ -927,12 +929,17 @@ class Analyzer(
// from LogicalPlan, currently we only do it for UnaryNode which has same output
// schema with its child.
case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
- val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e =>
- val ne = e match {
- case n: NamedExpression => n
- case _ => Alias(e, "_nondeterministic")()
+ val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
+ val leafNondeterministic = expr.collect {
+ case n: Nondeterministic => n
+ }
+ leafNondeterministic.map { e =>
+ val ne = e match {
+ case n: NamedExpression => n
+ case _ => Alias(e, "_nondeterministic")()
+ }
+ new TreeNodeRef(e) -> ne
}
- new TreeNodeRef(e) -> ne
}.toMap
val newPlan = p.transformExpressions { case e =>
nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
@@ -941,6 +948,63 @@ class Analyzer(
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
}
+
+ /**
+ * Removes from [[Sort]] all ordering expressions that still need evaluation, materializes them
+ * with an inner Project, and finally adds an outer Project to project them away so the result
+ * stays the same. This guarantees that we only sort by [[AttributeReference]]s.
+ *
+ * As an example,
+ * {{{
+ * Sort('a, 'b + 1,
+ * Relation('a, 'b))
+ * }}}
+ * will be turned into:
+ * {{{
+ * Project('a, 'b,
+ * Sort('a, '_sortCondition,
+ * Project('a, 'b, ('b + 1).as("_sortCondition"),
+ * Relation('a, 'b))))
+ * }}}
+ */
+ object RemoveEvaluationFromSort extends Rule[LogicalPlan] {
+ private def hasAlias(expr: Expression) = {
+ expr.find {
+ case a: Alias => true
+ case _ => false
+ }.isDefined
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // The ordering expressions have no effect on the output schema of `Sort`,
+ // so `Alias`es in ordering expressions are unnecessary and we should remove them.
+ case s @ Sort(ordering, _, _) if ordering.exists(hasAlias) =>
+ val newOrdering = ordering.map(_.transformUp {
+ case Alias(child, _) => child
+ }.asInstanceOf[SortOrder])
+ s.copy(order = newOrdering)
+
+ case s @ Sort(ordering, global, child)
+ if s.expressions.forall(_.resolved) && s.childrenResolved && !s.hasNoEvaluation =>
+
+ val (ref, needEval) = ordering.partition(_.child.isInstanceOf[AttributeReference])
+
+ val namedExpr = needEval.map(_.child match {
+ case n: NamedExpression => n
+ case e => Alias(e, "_sortCondition")()
+ })
+
+ val newOrdering = ref ++ needEval.zip(namedExpr).map { case (order, ne) =>
+ order.copy(child = ne.toAttribute)
+ }
+
+ // Add the ordering expressions that still need evaluation to the inner Project, then
+ // project them away after the sort.
+ Project(child.output,
+ Sort(newOrdering, global,
+ Project(child.output ++ namedExpr, child)))
+ }
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a373714832962..0ebc3d180a780 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -87,6 +87,18 @@ trait CheckAnalysis {
s"join condition '${condition.prettyString}' " +
s"of type ${condition.dataType.simpleString} is not a boolean.")
+ case j @ Join(_, _, _, Some(condition)) =>
+ def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
+ case p: Predicate =>
+ p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
+ case e if e.dataType.isInstanceOf[BinaryType] =>
+ failAnalysis(s"expression ${e.prettyString} in join condition " +
+ s"'${condition.prettyString}' can't be binary type.")
+ case _ => // OK
+ }
+
+ checkValidJoinConditionExprs(condition)
+
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
@@ -100,7 +112,15 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
+ def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match {
+ case BinaryType =>
+ failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " +
+ s"not be binary type.")
+ case _ => // OK
+ }
+
aggregateExprs.foreach(checkValidAggregateExpression)
+ aggregateExprs.foreach(checkValidGroupingExprs)
case Sort(orders, _, _) =>
orders.foreach { order =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index aa05f448d12bc..1bf7204a2515c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -161,13 +161,6 @@ object FunctionRegistry {
expression[ToDegrees]("degrees"),
expression[ToRadians]("radians"),
- // misc functions
- expression[Md5]("md5"),
- expression[Sha2]("sha2"),
- expression[Sha1]("sha1"),
- expression[Sha1]("sha"),
- expression[Crc32]("crc32"),
-
// aggregate functions
expression[Average]("avg"),
expression[Count]("count"),
@@ -212,22 +205,41 @@ object FunctionRegistry {
expression[Upper]("upper"),
// datetime functions
+ expression[AddMonths]("add_months"),
expression[CurrentDate]("current_date"),
expression[CurrentTimestamp]("current_timestamp"),
+ expression[DateAdd]("date_add"),
expression[DateFormatClass]("date_format"),
+ expression[DateSub]("date_sub"),
expression[DayOfMonth]("day"),
expression[DayOfYear]("dayofyear"),
expression[DayOfMonth]("dayofmonth"),
+ expression[FromUnixTime]("from_unixtime"),
expression[Hour]("hour"),
- expression[Month]("month"),
+ expression[LastDay]("last_day"),
expression[Minute]("minute"),
+ expression[Month]("month"),
+ expression[MonthsBetween]("months_between"),
+ expression[NextDay]("next_day"),
expression[Quarter]("quarter"),
expression[Second]("second"),
+ expression[ToDate]("to_date"),
+ expression[TruncDate]("trunc"),
+ expression[UnixTimestamp]("unix_timestamp"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
// collection functions
- expression[Size]("size")
+ expression[Size]("size"),
+
+ // misc functions
+ expression[Crc32]("crc32"),
+ expression[Md5]("md5"),
+ expression[Sha1]("sha"),
+ expression[Sha1]("sha1"),
+ expression[Sha2]("sha2"),
+ expression[SparkPartitionID]("spark_partition_id"),
+ expression[InputFileName]("input_file_name")
)
val builtin: FunctionRegistry = {
@@ -237,7 +249,7 @@ object FunctionRegistry {
}
/** See usage above. */
- private def expression[T <: Expression](name: String)
+ def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
// See if we can find a constructor that accepts Seq[Expression]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e0527503442f0..603afc4032a37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -47,6 +47,7 @@ object HiveTypeCoercion {
Division ::
PropagateTypes ::
ImplicitTypeCasts ::
+ DateTimeOperations ::
Nil
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -109,13 +110,35 @@ object HiveTypeCoercion {
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
*/
- private def findTightestCommonType(types: Seq[DataType]) = {
+ private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonTypeOfTwo(d, c)
})
}
+ private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
+ case (t1: DecimalType, t2: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(t1, t2))
+ case (t: IntegralType, d: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (d: DecimalType, t: IntegralType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (t: FractionalType, d: DecimalType) =>
+ Some(DoubleType)
+ case (d: DecimalType, t: FractionalType) =>
+ Some(DoubleType)
+ case _ =>
+ findTightestCommonTypeToString(t1, t2)
+ }
+
+ private def findWiderCommonType(types: Seq[DataType]) = {
+ types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+ case Some(d) => findWiderTypeForTwo(d, c)
+ case None => None
+ })
+ }
+
/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
@@ -182,20 +205,7 @@ object HiveTypeCoercion {
val castedTypes = left.output.zip(right.output).map {
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- (lhs.dataType, rhs.dataType) match {
- case (t1: DecimalType, t2: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(t1, t2))
- case (t: IntegralType, d: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (d: DecimalType, t: IntegralType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (t: FractionalType, d: DecimalType) =>
- Some(DoubleType)
- case (d: DecimalType, t: FractionalType) =>
- Some(DoubleType)
- case _ =>
- findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
- }
+ findWiderTypeForTwo(lhs.dataType, rhs.dataType)
case other => None
}
@@ -236,8 +246,13 @@ object HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ BinaryArithmetic(left @ StringType(), r) =>
- a.makeCopy(Array(Cast(left, DoubleType), r))
+ case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) =>
+ a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right))
+ case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) =>
+ a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT)))
+
+ case a @ BinaryArithmetic(left @ StringType(), right) =>
+ a.makeCopy(Array(Cast(left, DoubleType), right))
case a @ BinaryArithmetic(left, right @ StringType()) =>
a.makeCopy(Array(left, Cast(right, DoubleType)))
@@ -543,7 +558,7 @@ object HiveTypeCoercion {
// compatible with every child column.
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
- findTightestCommonTypeAndPromoteToString(types) match {
+ findWiderCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}
@@ -624,6 +639,27 @@ object HiveTypeCoercion {
}
}
+ /**
+ * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
+ * into TimeAdd/TimeSub
+ */
+ object DateTimeOperations extends Rule[LogicalPlan] {
+
+ private val acceptedTypes = Seq(DateType, TimestampType, StringType)
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) =>
+ Cast(TimeAdd(r, l), r.dataType)
+ case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
+ Cast(TimeAdd(l, r), l.dataType)
+ case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
+ Cast(TimeSub(l, r), l.dataType)
+ }
+ }
+
/**
* Casts types according to the expected input types for [[Expression]]s.
*/
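An illustrative sketch, not part of the patch itself: `findWiderCommonType` reduces a list of types with a left fold over the pairwise widening rule, so a single incompatible pair makes the whole result `None`. The toy version below uses plain strings in place of Catalyst `DataType`s to show just that shape.

```scala
// Pairwise widening over a toy type lattice; None means "no common wider type".
def widerOfTwo(t1: String, t2: String): Option[String] = (t1, t2) match {
  case (a, b) if a == b                      => Some(a)
  case ("null", t)                           => Some(t)
  case ("int", "double") | ("double", "int") => Some("double")
  case _                                     => None
}

def widerCommonType(types: Seq[String]): Option[String] =
  types.foldLeft[Option[String]](Some("null")) { (r, c) =>
    r match {
      case Some(d) => widerOfTwo(d, c)
      case None    => None
    }
  }

// widerCommonType(Seq("int", "double", "int")) == Some("double")
// widerCommonType(Seq("int", "string"))        == None
```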
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 41a877f214e55..45709c1c8f554 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -48,9 +48,9 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
case DoubleType => input.getDouble(ordinal)
case StringType => input.getUTF8String(ordinal)
case BinaryType => input.getBinary(ordinal)
- case IntervalType => input.getInterval(ordinal)
+ case CalendarIntervalType => input.getInterval(ordinal)
case t: StructType => input.getStruct(ordinal, t.size)
- case dataType => input.get(ordinal, dataType)
+ case _ => input.get(ordinal, dataType)
}
}
}
@@ -64,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def exprId: ExprId = throw new UnsupportedOperationException
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val javaType = ctx.javaType(dataType)
+ val value = ctx.getValue("i", dataType, ordinal.toString)
s"""
- boolean ${ev.isNull} = i.isNullAt($ordinal);
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
- ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)});
+ boolean ${ev.isNull} = i.isNullAt($ordinal);
+ $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
"""
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index bd8b0177eb00e..43be11c48ae7c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{Interval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import scala.collection.mutable
@@ -55,7 +55,7 @@ object Cast {
case (_, DateType) => true
- case (StringType, IntervalType) => true
+ case (StringType, CalendarIntervalType) => true
case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
@@ -225,7 +225,7 @@ case class Cast(child: Expression, dataType: DataType)
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => Interval.fromString(s.toString))
+ buildCast[UTF8String](_, s => CalendarInterval.fromString(s.toString))
case _ => _ => null
}
@@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
val elementCast = cast(from.elementType, to.elementType)
- buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
+ // TODO: Could be faster?
+ buildCast[ArrayData](_, array => {
+ val length = array.numElements()
+ val values = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ if (array.isNullAt(i)) {
+ values(i) = null
+ } else {
+ values(i) = elementCast(array.get(i))
+ }
+ i += 1
+ }
+ new GenericArrayData(values)
+ })
}
private[this] def castMap(from: MapType, to: MapType): Any => Any = {
@@ -398,7 +412,7 @@ case class Cast(child: Expression, dataType: DataType)
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
- case IntervalType => castToInterval(from)
+ case CalendarIntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
@@ -438,7 +452,7 @@ case class Cast(child: Expression, dataType: DataType)
case DateType => castToDateCode(from, ctx)
case decimal: DecimalType => castToDecimalCode(from, decimal)
case TimestampType => castToTimestampCode(from, ctx)
- case IntervalType => castToIntervalCode(from)
+ case CalendarIntervalType => castToIntervalCode(from)
case BooleanType => castToBooleanCode(from)
case ByteType => castToByteCode(from)
case ShortType => castToShortCode(from)
@@ -599,7 +613,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
case _: IntegralType =>
(c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
case DateType =>
@@ -630,7 +644,7 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) =>
- s"$evPrim = Interval.fromString($c.toString());"
+ s"$evPrim = CalendarInterval.fromString($c.toString());"
}
private[this] def decimalToTimestampCode(d: String): String =
@@ -665,7 +679,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -687,7 +701,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -731,7 +745,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -753,7 +767,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -775,7 +789,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castArrayCode(
from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)
-
- val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+ val arrayClass = classOf[GenericArrayData].getName
val fromElementNull = ctx.freshName("feNull")
val fromElementPrim = ctx.freshName("fePrim")
val toElementNull = ctx.freshName("teNull")
val toElementPrim = ctx.freshName("tePrim")
val size = ctx.freshName("n")
val j = ctx.freshName("j")
- val result = ctx.freshName("result")
+ val values = ctx.freshName("values")
(c, evPrim, evNull) =>
s"""
- final int $size = $c.size();
- final $arraySeqClass $result = new $arraySeqClass($size);
+ final int $size = $c.numElements();
+ final Object[] $values = new Object[$size];
for (int $j = 0; $j < $size; $j ++) {
- if ($c.apply($j) == null) {
- $result.update($j, null);
+ if ($c.isNullAt($j)) {
+ $values[$j] = null;
} else {
boolean $fromElementNull = false;
${ctx.javaType(from.elementType)} $fromElementPrim =
- (${ctx.boxedType(from.elementType)}) $c.apply($j);
+ ${ctx.getValue(c, from.elementType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)}
if ($toElementNull) {
- $result.update($j, null);
+ $values[$j] = null;
} else {
- $result.update($j, $toElementPrim);
+ $values[$j] = $toElementPrim;
}
}
}
- $evPrim = $result;
+ $evPrim = new $arrayClass($values);
"""
}
@@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType)
$result.setNullAt($i);
} else {
$fromType $fromFieldPrim =
- ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)};
+ ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index cb4c3f24b2721..8fc182607ce68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -201,11 +201,9 @@ trait Nondeterministic extends Expression {
private[this] var initialized = false
- final def initialize(): Unit = {
- if (!initialized) {
- initInternal()
- initialized = true
- }
+ final def setInitialValues(): Unit = {
+ initInternal()
+ initialized = true
}
protected def initInternal(): Unit
@@ -355,9 +353,9 @@ abstract class BinaryExpression extends Expression {
* @param f accepts two variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
- ctx: CodeGenContext,
- ev: GeneratedExpressionCode,
- f: (String, String) => String): String = {
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: (String, String) => String): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"${ev.primitive} = ${f(eval1, eval2)};"
})
@@ -372,9 +370,9 @@ abstract class BinaryExpression extends Expression {
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
- ctx: CodeGenContext,
- ev: GeneratedExpressionCode,
- f: (String, String) => String): String = {
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: (String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitive, eval2.primitive)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
new file mode 100644
index 0000000000000..1e74f716955e3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.rdd.SqlNewHadoopRDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Expression that returns the name of the file currently being read in via [[SqlNewHadoopRDD]]
+ */
+case class InputFileName() extends LeafExpression with Nondeterministic {
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = StringType
+
+ override val prettyName = "INPUT_FILE_NAME"
+
+ override protected def initInternal(): Unit = {}
+
+ override protected def evalInternal(input: InternalRow): UTF8String = {
+ SqlNewHadoopRDD.getInputFileName()
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ ev.isNull = "false"
+ s"final ${ctx.javaType(dataType)} ${ev.primitive} = " +
+ "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();"
+ }
+
+}
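An illustrative usage sketch, not part of the patch itself: it assumes an existing `sqlContext`, a file-based source read through `SqlNewHadoopRDD` (e.g. Parquet), and a hypothetical path; rows that do not come from such a source are assumed to report an empty name.

```scala
// Tag each row with the file it was read from (path is illustrative).
val events = sqlContext.read.parquet("/tmp/events")
events.selectExpr("*", "input_file_name() AS source_file").show()
```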
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
similarity index 95%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index eca36b3274420..291b7a5bc3af5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -15,11 +15,10 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.expressions
+package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{LongType, DataType}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 27d6ff587ab71..7c7664e4c1a91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
-import org.apache.spark.sql.types.{StructType, DataType}
+import org.apache.spark.sql.types.{Decimal, StructType, DataType}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -32,7 +32,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
expressions.foreach(_.foreach {
- case n: Nondeterministic => n.initialize()
+ case n: Nondeterministic => n.setInitialValues()
case _ =>
})
@@ -63,7 +63,7 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
expressions.foreach(_.foreach {
- case n: Nondeterministic => n.initialize()
+ case n: Nondeterministic => n.setInitialValues()
case _ =>
})
@@ -225,6 +225,11 @@ class JoinedRow extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
+ override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
+ if (i < row1.numFields) row1.getDecimal(i, precision, scale)
+ else row2.getDecimal(i - row1.numFields, precision, scale)
+ }
+
override def getStruct(i: Int, numFields: Int): InternalRow = {
if (i < row1.numFields) {
row1.getStruct(i, numFields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3f436c0eb893c..9fe877f10fa08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
abstract sealed class SortDirection
case object Ascending extends SortDirection
@@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection)
override def nullable: Boolean = child.nullable
override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+
+ def isAscending: Boolean = direction == Ascending
+}
+
+/**
+ * An expression to generate a 64-bit long prefix used in sorting.
+ */
+case class SortPrefix(child: SortOrder) extends UnaryExpression {
+
+ override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val childCode = child.child.gen(ctx)
+ val input = childCode.primitive
+ val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
+
+ val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+ case BooleanType =>
+ (Long.MinValue, s"$input ? 1L : 0L")
+ case _: IntegralType =>
+ (Long.MinValue, s"(long) $input")
+ case FloatType | DoubleType =>
+ (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
+ s"$DoublePrefixCmp.computePrefix((double)$input)")
+ case StringType => (0L, s"$input.getPrefix()")
+ case _ => (0L, "0L")
+ }
+
+ childCode.code +
+ s"""
+ |long ${ev.primitive} = ${nullValue}L;
+ |boolean ${ev.isNull} = false;
+ |if (!${childCode.isNull}) {
+ | ${ev.primitive} = $prefixCode;
+ |}
+ """.stripMargin
+ }
+
+ override def dataType: DataType = LongType
}
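An illustrative sketch, not part of the patch itself: a single long can stand in for a double during sorting because there is an order-preserving bit encoding from doubles to signed longs. The helper below shows the classic encoding; it is only meant to illustrate the idea and is not claimed to be the exact `DoublePrefixComparator.computePrefix` implementation.

```scala
// Map a double to a long whose signed ordering matches the double ordering.
def sortableLongPrefix(d: Double): Long = {
  val bits = java.lang.Double.doubleToLongBits(d)
  // For negative doubles, flip the magnitude bits so larger magnitudes sort lower.
  bits ^ ((bits >> 63) & 0x7fffffffffffffffL)
}

assert(sortableLongPrefix(-2.0) < sortableLongPrefix(-0.5))
assert(sortableLongPrefix(-0.5) < sortableLongPrefix(0.0))
assert(sortableLongPrefix(0.0)  < sortableLongPrefix(1.5))
```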
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
similarity index 88%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 61ef079d89af5..4b1772a2deed5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -15,11 +15,10 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.expressions
+package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{IntegerType, DataType}
@@ -27,7 +26,7 @@ import org.apache.spark.sql.types.{IntegerType, DataType}
/**
* Expression that returns the current partition id of the Spark task.
*/
-private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic {
+private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override def nullable: Boolean = false
@@ -35,6 +34,8 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi
@transient private[this] var partitionId: Int = _
+ override val prettyName = "SPARK_PARTITION_ID"
+
override protected def initInternal(): Unit = {
partitionId = TaskContext.getPartitionId()
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 10bd19c8a840f..d08f553cefe8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
private[sql] case object PartialMerge extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
* containing intermediate results for this function and then generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
@@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
private[sql] case object Final extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
* from original input rows without any partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the final result of this function is returned.
@@ -103,9 +103,30 @@ abstract class AggregateFunction2
final override def foldable: Boolean = false
/**
- * The offset of this function's buffer in the underlying buffer shared with other functions.
+ * The offset of this function's start buffer value in the
+ * underlying shared mutable aggregation buffer.
+ * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share
+ * the same aggregation buffer. In this shared buffer, the position of the first
+ * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
+ * will be 2.
+ */
+ var mutableBufferOffset: Int = 0
+
+ /**
+ * The offset of this function's start buffer value in the
+ * underlying shared input aggregation buffer. An input aggregation buffer is used
+ * when we merge two aggregation buffers; it is essentially the immutable side of the merge
+ * (we merge an input aggregation buffer into a mutable aggregation buffer and
+ * then store the new buffer values back in the mutable aggregation buffer).
+ * Usually, an input aggregation buffer also contains extra elements like grouping
+ * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often
+ * different.
+ * For example, we have a grouping expression `key`, and two aggregate functions
+ * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first
+ * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
+ * will be 3 (position 0 is used for the value of `key`).
*/
- var bufferOffset: Int = 0
+ var inputBufferOffset: Int = 0
/** The schema of the aggregation buffer. */
def bufferSchema: StructType
@@ -176,7 +197,7 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
override def initialize(buffer: MutableRow): Unit = {
var i = 0
while (i < bufferAttributes.size) {
- buffer(i + bufferOffset) = initialValues(i).eval()
+ buffer(i + mutableBufferOffset) = initialValues(i).eval()
i += 1
}
}
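A tiny worked example of the two offsets documented above, not part of the patch itself: for `SELECT key, avg(x), avg(y) ... GROUP BY key`, assuming each `avg` keeps a (sum, count) pair as in the doc comment.

```scala
// Mutable aggregation buffer layout:           [sumX, countX, sumY, countY]
// Input aggregation buffer layout (for merge): [key, sumX, countX, sumY, countY]
case class BufferOffsets(mutableBufferOffset: Int, inputBufferOffset: Int)

val avgX = BufferOffsets(mutableBufferOffset = 0, inputBufferOffset = 1)
val avgY = BufferOffsets(mutableBufferOffset = 2, inputBufferOffset = 3)
```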
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
new file mode 100644
index 0000000000000..4a43318a95490
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert our plan to the new aggregation code path.
+ */
+object Utils {
+ // Right now, we do not support complex types in the grouping key schema.
+ private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+ val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+ case array: ArrayType => true
+ case map: MapType => true
+ case struct: StructType => true
+ case _ => false
+ }
+
+ !hasComplexTypes
+ }
+
+ private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+ case p: Aggregate if supportsGroupingKeySchema(p) =>
+ val converted = p.transformExpressionsDown {
+ case expressions.Average(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Average(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Count(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ // We do not support multiple COUNT DISTINCT columns for now.
+ case expressions.CountDistinct(children) if children.length == 1 =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(children.head),
+ mode = aggregate.Complete,
+ isDistinct = true)
+
+ case expressions.First(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.First(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Last(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Last(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Max(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Max(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Min(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Min(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Sum(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.SumDistinct(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = true)
+ }
+ // Check if there is any expressions.AggregateExpression1 left.
+ // If so, we cannot convert this plan.
+ val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+ // For every expression, check if it contains AggregateExpression1.
+ expr.find {
+ case agg: expressions.AggregateExpression1 => true
+ case other => false
+ }.isDefined
+ }
+
+ // Check if there are multiple distinct columns.
+ val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg
+ }
+ }.toSet.toSeq
+ val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+ val hasMultipleDistinctColumnSets =
+ functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
+
+ if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+ case other => None
+ }
+
+ def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+ // If the plan cannot be converted, we will do a final round check to see if the original
+ // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
+ // we need to throw an exception.
+ val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg.aggregateFunction
+ }
+ }.distinct
+ if (aggregateFunction2s.nonEmpty) {
+ // For functions implemented based on the new interface, prepare a list of function names.
+ val invalidFunctions = {
+ if (aggregateFunction2s.length > 1) {
+ s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+ s"and ${aggregateFunction2s.head.nodeName} are"
+ } else {
+ s"${aggregateFunction2s.head.nodeName} is"
+ }
+ }
+ val errorMessage =
+ s"${invalidFunctions} implemented based on the new Aggregate Function " +
+ s"interface and it cannot be used with functions implemented based on " +
+ s"the old Aggregate Function interface."
+ throw new AnalysisException(errorMessage)
+ }
+ }
+
+ def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+ case p: Aggregate =>
+ val converted = doConvert(p)
+ if (converted.isDefined) {
+ converted
+ } else {
+ checkInvalidAggregateFunction2(p)
+ None
+ }
+ case other => None
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 42343d4d8d79c..5d4b349b1597a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -404,7 +404,7 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
// partialSum already increase the precision by 10
val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
- val castedCount = Sum(partialCount.toAttribute)
+ val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
SplitEvaluation(
Cast(Divide(castedSum, castedCount), dataType),
partialCount :: partialSum :: Nil)
@@ -490,13 +490,13 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
case DecimalType.Fixed(_, _) =>
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
- Cast(CombineSum(partialSum.toAttribute), dataType),
+ Cast(Sum(partialSum.toAttribute), dataType),
partialSum :: Nil)
case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
- CombineSum(partialSum.toAttribute),
+ Sum(partialSum.toAttribute),
partialSum :: Nil)
}
}
@@ -522,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
private val sum = MutableLiteral(null, calcType)
- private val addFunction =
- Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
+ private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
override def update(input: InternalRow): Unit = {
sum.update(addFunction, input)
@@ -538,67 +537,6 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
}
}
-/**
- * Sum should satisfy 3 cases:
- * 1) sum of all null values = zero
- * 2) sum for table column with no data = null
- * 3) sum of column with null and not null values = sum of not null values
- * Require separate CombineSum Expression and function as it has to distinguish "No data" case
- * versus "data equals null" case, while aggregating results and at each partial expression.i.e.,
- * Combining PartitionLevel InputData
- * <-- null
- * Zero <-- Zero <-- null
- *
- * <-- null <-- no data
- * null <-- null <-- no data
- */
-case class CombineSum(child: Expression) extends AggregateExpression1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = child :: Nil
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"CombineSum($child)"
- override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
-}
-
-case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- private val calcType =
- expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- expr.dataType
- }
-
- private val zero = Cast(Literal(0), calcType)
-
- private val sum = MutableLiteral(null, calcType)
-
- private val addFunction =
- Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
-
- override def update(input: InternalRow): Unit = {
- val result = expr.eval(input)
- // partial sum result can be null only when no input rows present
- if(result != null) {
- sum.update(addFunction, input)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- Cast(sum, dataType).eval(null)
- case _ => sum.eval(null)
- }
- }
-}
-
case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
def this() = this(null)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index b37f530ec6814..6f8f4dd230f12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.Interval
+import org.apache.spark.unsafe.types.CalendarInterval
case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -37,12 +37,12 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
- case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
+ case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}
protected override def nullSafeEval(input: Any): Any = {
- if (dataType.isInstanceOf[IntervalType]) {
- input.asInstanceOf[Interval].negate()
+ if (dataType.isInstanceOf[CalendarIntervalType]) {
+ input.asInstanceOf[CalendarInterval].negate()
} else {
numeric.negate(input)
}
@@ -68,8 +68,7 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
extended = "> SELECT _FUNC_('-1');\n1")
-case class Abs(child: Expression)
- extends UnaryExpression with ExpectsInputTypes with CodegenFallback {
+case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
@@ -122,8 +121,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
- if (dataType.isInstanceOf[IntervalType]) {
- input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval])
+ if (dataType.isInstanceOf[CalendarIntervalType]) {
+ input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval])
} else {
numeric.plus(input1, input2)
}
@@ -135,7 +134,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
- case IntervalType =>
+ case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
@@ -151,8 +150,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
- if (dataType.isInstanceOf[IntervalType]) {
- input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval])
+ if (dataType.isInstanceOf[CalendarIntervalType]) {
+ input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval])
} else {
numeric.minus(input1, input2)
}
@@ -164,7 +163,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
- case IntervalType =>
+ case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
index 2087cc7f109bc..c98182c96b165 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
/**
- * An utility class that indents a block of code based on the curly braces.
- *
+ * A utility class that indents a block of code based on the curly braces and parentheses.
* This is used to prettify generated code when in debug mode (or exceptions).
*
* Written by Matei Zaharia.
@@ -35,11 +34,12 @@ private class CodeFormatter {
private var indentString = ""
private def addLine(line: String): Unit = {
- val indentChange = line.count(_ == '{') - line.count(_ == '}')
+ val indentChange =
+ line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0)
val newIndentLevel = math.max(0, indentLevel + indentChange)
// Lines starting with '}' should be de-indented even if they contain '{' after;
// in addition, lines ending with ':' are typically labels
- val thisLineIndent = if (line.startsWith("}") || line.endsWith(":")) {
+ val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) {
" " * (indentSize * (indentLevel - 1))
} else {
indentString
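An illustrative sketch of the indent-delta rule after the change above, not part of the patch itself: a line's net indentation change now counts opening and closing parentheses as well as braces.

```scala
// Net indentation change contributed by one line of generated code.
def indentChange(line: String): Int =
  line.count(c => c == '{' || c == '(') - line.count(c => c == '}' || c == ')')

assert(indentChange("public void apply(InternalRow i) {") == 1) // '(' and ')' cancel, '{' opens
assert(indentChange("));") == -2)                               // two closers, nothing opened
```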
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2f02c90b1d5b3..60e2863f7bbb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -100,17 +100,19 @@ class CodeGenContext {
}
/**
- * Returns the code to access a column in Row for a given DataType.
+ * Returns the code to access a value in `SpecializedGetters` for a given DataType.
*/
- def getColumn(row: String, dataType: DataType, ordinal: Int): String = {
+ def getValue(getter: String, dataType: DataType, ordinal: String): String = {
val jt = javaType(dataType)
dataType match {
- case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
- case StringType => s"$row.getUTF8String($ordinal)"
- case BinaryType => s"$row.getBinary($ordinal)"
- case IntervalType => s"$row.getInterval($ordinal)"
- case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
- case _ => s"($jt)$row.get($ordinal)"
+ case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
+ case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})"
+ case StringType => s"$getter.getUTF8String($ordinal)"
+ case BinaryType => s"$getter.getBinary($ordinal)"
+ case CalendarIntervalType => s"$getter.getInterval($ordinal)"
+ case t: StructType => s"$getter.getStruct($ordinal, ${t.size})"
+ case a: ArrayType => s"$getter.getArray($ordinal)"
+ case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter.
}
}
@@ -119,10 +121,10 @@ class CodeGenContext {
*/
def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
val jt = javaType(dataType)
- if (isPrimitiveType(jt)) {
- s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
- } else {
- s"$row.update($ordinal, $value)"
+ dataType match {
+ case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
+ case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+ case _ => s"$row.update($ordinal, $value)"
}
}
@@ -150,10 +152,10 @@ class CodeGenContext {
case dt: DecimalType => "Decimal"
case BinaryType => "byte[]"
case StringType => "UTF8String"
- case IntervalType => "Interval"
+ case CalendarIntervalType => "CalendarInterval"
case _: StructType => "InternalRow"
- case _: ArrayType => s"scala.collection.Seq"
- case _: MapType => s"scala.collection.Map"
+ case _: ArrayType => "ArrayData"
+ case _: MapType => "scala.collection.Map"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
@@ -214,7 +216,9 @@ class CodeGenContext {
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case NullType => "0"
- case other => s"$c1.compare($c2)"
+ case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
+ case _ => throw new IllegalArgumentException(
+ "cannot generate compare code for un-comparable type")
}
/**
@@ -293,7 +297,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[UnsafeRow].getName,
classOf[UTF8String].getName,
classOf[Decimal].getName,
- classOf[Interval].getName
+ classOf[CalendarInterval].getName,
+ classOf[ArrayData].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 6b187f05604fd..3492d2c6189ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression}
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression
trait CodegenFallback extends Expression {
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ foreach {
+ case n: Nondeterministic => n.setInitialValues()
+ case _ =>
+ }
+
ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 9d2161947b351..1d223986d9441 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -34,15 +34,74 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
+ private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName
+ private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName
+ private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
case t: AtomicType if !t.isInstanceOf[DecimalType] => true
- case _: IntervalType => true
+ case _: CalendarIntervalType => true
+ case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case NullType => true
+ case t: DecimalType => true
case _ => false
}
+ def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
+ case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+ s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))"
+ case StringType =>
+ s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
+ case BinaryType =>
+ s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
+ case CalendarIntervalType =>
+ s" + (${ev.isNull} ? 0 : 16)"
+ case _: StructType =>
+ s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+ case _ => ""
+ }
+
+ def genFieldWriter(
+ ctx: CodeGenContext,
+ fieldType: DataType,
+ ev: GeneratedExpressionCode,
+ primitive: String,
+ index: Int,
+ cursor: String): String = fieldType match {
+ case _ if ctx.isPrimitiveType(fieldType) =>
+ s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}"
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+ s"""
+ // make sure Decimal object has the same scale as DecimalType
+ if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
+ $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+ } else {
+ $primitive.setNullAt($index);
+ }
+ """
+ case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+ s"""
+ // make sure Decimal object has the same scale as DecimalType
+ if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
+ $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+ } else {
+ $primitive.setNullAt($index);
+ }
+ """
+ case StringType =>
+ s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ case BinaryType =>
+ s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ case CalendarIntervalType =>
+ s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ case t: StructType =>
+ s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ case NullType => ""
+ case _ =>
+ throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
+ }
+
/**
* Generates the code to create an [[UnsafeRow]] object based on the input expressions.
* @param ctx context for code generation
@@ -55,41 +114,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val ret = ev.primitive
ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
- val bufferTerm = ctx.freshName("buffer")
- ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];")
- val cursorTerm = ctx.freshName("cursor")
- val numBytesTerm = ctx.freshName("numBytes")
+ val buffer = ctx.freshName("buffer")
+ ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+ val cursor = ctx.freshName("cursor")
+ val numBytes = ctx.freshName("numBytes")
- val exprs = expressions.map(_.gen(ctx))
+ val exprs = expressions.map { e => e.dataType match {
+ case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st)
+ case _ => e.gen(ctx)
+ }}
val allExprs = exprs.map(_.code).mkString("\n")
- val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
- val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
- e.dataType match {
- case StringType =>
- s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
- case BinaryType =>
- s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
- case IntervalType =>
- s" + (${exprs(i).isNull} ? 0 : 16)"
- case _ => ""
- }
+ val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+ val additionalSize = expressions.zipWithIndex.map {
+ case (e, i) => genAdditionalSize(e.dataType, exprs(i))
}.mkString("")
val writers = expressions.zipWithIndex.map { case (e, i) =>
- val update = e.dataType match {
- case dt if ctx.isPrimitiveType(dt) =>
- s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
- case StringType =>
- s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
- case BinaryType =>
- s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
- case IntervalType =>
- s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
- case NullType => ""
- case _ =>
- throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
- }
+ val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor)
s"""if (${exprs(i).isNull}) {
$ret.setNullAt($i);
} else {
@@ -99,24 +141,115 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
$allExprs
- int $numBytesTerm = $fixedSize $additionalSize;
- if ($numBytesTerm > $bufferTerm.length) {
- $bufferTerm = new byte[$numBytesTerm];
+ int $numBytes = $fixedSize $additionalSize;
+ if ($numBytes > $buffer.length) {
+ $buffer = new byte[$numBytes];
}
$ret.pointTo(
- $bufferTerm,
+ $buffer,
org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
${expressions.size},
- $numBytesTerm);
- int $cursorTerm = $fixedSize;
-
+ $numBytes);
+ int $cursor = $fixedSize;
$writers
boolean ${ev.isNull} = false;
"""
}
+ /**
+ * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
+ *
+ * This function also handles nested structs by recursively generating the code to do the conversion.
+ *
+ * @param ctx code generation context
+ * @param input the input struct, identified by a [[GeneratedExpressionCode]]
+ * @param schema schema of the struct field
+ */
+ // TODO: refactor createCode and this function to reduce code duplication.
+ private def createCodeForStruct(
+ ctx: CodeGenContext,
+ input: GeneratedExpressionCode,
+ schema: StructType): GeneratedExpressionCode = {
+
+ val isNull = input.isNull
+ val primitive = ctx.freshName("structConvert")
+ ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();")
+ val buffer = ctx.freshName("buffer")
+ ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+ val cursor = ctx.freshName("cursor")
+
+ val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map {
+ case (dt, i) => dt match {
+ case st: StructType =>
+ val nestedStructEv = GeneratedExpressionCode(
+ code = "",
+ isNull = s"${input.primitive}.isNullAt($i)",
+ primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
+ )
+ createCodeForStruct(ctx, nestedStructEv, st)
+ case _ =>
+ GeneratedExpressionCode(
+ code = "",
+ isNull = s"${input.primitive}.isNullAt($i)",
+ primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
+ )
+ }
+ }
+ val allExprs = exprs.map(_.code).mkString("\n")
+
+ val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+ val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
+ genAdditionalSize(dt, ev)
+ }.mkString("")
+
+ val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
+ val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor)
+ s"""
+ if (${exprs(i).isNull}) {
+ $primitive.setNullAt($i);
+ } else {
+ $update;
+ }
+ """
+ }.mkString("\n ")
+
+ // Note that we add a shortcut here for performance: if the input is already an UnsafeRow,
+ // just copy the bytes directly into our buffer space without running any conversion.
+ // We also had to use a hack to introduce a "tmp" variable, to keep the Java compiler from
+ // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow.
+ val tmp = ctx.freshName("tmp")
+ val numBytes = ctx.freshName("numBytes")
+ val code = s"""
+ |${input.code}
+ |if (!${input.isNull}) {
+ | Object $tmp = (Object) ${input.primitive};
+ | if ($tmp instanceof UnsafeRow) {
+ | $primitive = (UnsafeRow) $tmp;
+ | } else {
+ | $allExprs
+ |
+ | int $numBytes = $fixedSize $additionalSize;
+ | if ($numBytes > $buffer.length) {
+ | $buffer = new byte[$numBytes];
+ | }
+ |
+ | $primitive.pointTo(
+ | $buffer,
+ | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+ | ${exprs.size},
+ | $numBytes);
+ | int $cursor = $fixedSize;
+ |
+ | $writers
+ | }
+ |}
+ """.stripMargin
+
+ GeneratedExpressionCode(code, isNull, primitive)
+ }
+
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
in.map(ExpressionCanonicalizer.execute)
@@ -132,18 +265,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
eval.code = createCode(ctx, eval, expressions)
val code = s"""
- private $exprType[] expressions;
-
- public Object generate($exprType[] expr) {
- this.expressions = expr;
- return new SpecificProjection();
+ public Object generate($exprType[] exprs) {
+ return new SpecificProjection(exprs);
}
class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
+ private $exprType[] expressions;
+
${declareMutableStates(ctx)}
- public SpecificProjection() {
+ public SpecificProjection($exprType[] expressions) {
+ this.expressions = expressions;
${initMutableStates(ctx)}
}
@@ -159,7 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
"""
- logDebug(s"code for ${expressions.mkString(",")}:\n$code")
+ logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
val c = compile(code)
c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2d92dcf23a86e..1a00dbc254de1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
override def nullSafeEval(value: Any): Int = child.dataType match {
- case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
- case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
+ case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
+ case _: MapType => value.asInstanceOf[Map[Any, Any]].size
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
+ val sizeCall = child.dataType match {
+ case _: ArrayType => "numElements()"
+ case _: MapType => "size()"
+ }
+ nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 119168fa59f15..a145dfb4bbf08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.collection.mutable
-
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -44,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
- children.map(_.eval(input))
+ new GenericArrayData(children.map(_.eval(input)).toArray)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+ val arrayClass = classOf[GenericArrayData].getName
s"""
- boolean ${ev.isNull} = false;
- $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size});
+ final boolean ${ev.isNull} = false;
+ final Object[] values = new Object[${children.size}];
""" +
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
- ${ev.primitive}.update($i, null);
+ values[$i] = null;
} else {
- ${ev.primitive}.update($i, ${eval.primitive});
+ values[$i] = ${eval.primitive};
}
"""
- }.mkString("\n")
+ }.mkString("\n") +
+ s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);"
}
override def prettyName: String = "array"
@@ -104,18 +104,19 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
- if (${eval.isNull}) {
- ${ev.primitive}.update($i, null);
- } else {
- ${ev.primitive}.update($i, ${eval.primitive});
- }
- """
+ if (${eval.isNull}) {
+ ${ev.primitive}.update($i, null);
+ } else {
+ ${ev.primitive}.update($i, ${eval.primitive});
+ }
+ """
}.mkString("\n")
}
override def prettyName: String = "struct"
}
+
/**
* Creates a struct with the given field names and values
*
@@ -126,11 +127,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
private lazy val (nameExprs, valExprs) =
children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
- private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
+ private lazy val names = nameExprs.map(_.eval(EmptyRow))
override lazy val dataType: StructType = {
val fields = names.zip(valExprs).map { case (name, valExpr) =>
- StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+ StructField(name.asInstanceOf[UTF8String].toString,
+ valExpr.dataType, valExpr.nullable, Metadata.empty)
}
StructType(fields)
}
@@ -143,14 +145,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
} else {
- val invalidNames =
- nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable)
+ val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
- s"Odd position only allow foldable and not-null StringType expressions, got :" +
+ s"Only foldable StringType expressions are allowed to appear at odd position , got :" +
s" ${invalidNames.mkString(",")}")
- } else {
+ } else if (names.forall(_ != null)) {
TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure("Field name should not be null")
}
}
}
@@ -168,14 +171,83 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
valExprs.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
- if (${eval.isNull}) {
- ${ev.primitive}.update($i, null);
- } else {
- ${ev.primitive}.update($i, ${eval.primitive});
- }
- """
+ if (${eval.isNull}) {
+ ${ev.primitive}.update($i, null);
+ } else {
+ ${ev.primitive}.update($i, ${eval.primitive});
+ }
+ """
}.mkString("\n")
}
override def prettyName: String = "named_struct"
}
+
+/**
+ * Returns a Row containing the evaluation of all child expressions. This is a variant that
+ * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
+ * this expression automatically at runtime.
+ */
+case class CreateStructUnsafe(children: Seq[Expression]) extends Expression {
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override lazy val resolved: Boolean = childrenResolved
+
+ override lazy val dataType: StructType = {
+ val fields = children.zipWithIndex.map { case (child, idx) =>
+ child match {
+ case ne: NamedExpression =>
+ StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+ case _ =>
+ StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
+ }
+ }
+ StructType(fields)
+ }
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ GenerateUnsafeProjection.createCode(ctx, ev, children)
+ }
+
+ override def prettyName: String = "struct_unsafe"
+}
+
+
+/**
+ * Creates a struct with the given field names and values. This is a variant that returns
+ * UnsafeRow directly. The unsafe projection operator replaces [[CreateNamedStruct]] with
+ * this expression automatically at runtime.
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
+ */
+case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression {
+
+ private lazy val (nameExprs, valExprs) =
+ children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
+
+ private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
+
+ override lazy val dataType: StructType = {
+ val fields = names.zip(valExprs).map { case (name, valExpr) =>
+ StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+ }
+ StructType(fields)
+ }
+
+ override def foldable: Boolean = valExprs.forall(_.foldable)
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ GenerateUnsafeProjection.createCode(ctx, ev, valExprs)
+ }
+
+ override def prettyName: String = "named_struct_unsafe"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 6331a9eb603ca..99393c9c76ab6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -57,7 +57,8 @@ object ExtractValue {
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
- GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+ GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
+ ordinal, fields.length, containsNull)
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
@@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)};
+ ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)};
}
"""
})
@@ -134,6 +135,7 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
+ numFields: Int,
containsNull: Boolean) extends UnaryExpression {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
@@ -141,26 +143,45 @@ case class GetArrayStructFields(
override def toString: String = s"$child.${field.name}"
protected override def nullSafeEval(input: Any): Any = {
- input.asInstanceOf[Seq[InternalRow]].map { row =>
- if (row == null) null else row.get(ordinal, field.dataType)
+ val array = input.asInstanceOf[ArrayData]
+ val length = array.numElements()
+ val result = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ if (array.isNullAt(i)) {
+ result(i) = null
+ } else {
+ val row = array.getStruct(i, numFields)
+ if (row.isNullAt(ordinal)) {
+ result(i) = null
+ } else {
+ result(i) = row.get(ordinal, field.dataType)
+ }
+ }
+ i += 1
}
+ new GenericArrayData(result)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val arraySeqClass = "scala.collection.mutable.ArraySeq"
- // TODO: consider using Array[_] for ArrayType child to avoid
- // boxing of primitives
+ val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
s"""
- final int n = $eval.size();
- final $arraySeqClass values = new $arraySeqClass(n);
+ final int n = $eval.numElements();
+ final Object[] values = new Object[n];
for (int j = 0; j < n; j++) {
- InternalRow row = (InternalRow) $eval.apply(j);
- if (row != null && !row.isNullAt($ordinal)) {
- values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
+ if ($eval.isNullAt(j)) {
+ values[j] = null;
+ } else {
+ final InternalRow row = $eval.getStruct(j, $numFields);
+ if (row.isNullAt($ordinal)) {
+ values[j] = null;
+ } else {
+ values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
+ }
}
}
- ${ev.primitive} = (${ctx.javaType(dataType)}) values;
+ ${ev.primitive} = new $arrayClass(values);
"""
})
}
@@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
- val baseValue = value.asInstanceOf[Seq[_]]
+ val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
- if (index >= baseValue.size || index < 0) {
+ if (index >= baseValue.numElements() || index < 0) {
null
} else {
- baseValue(index)
+ baseValue.get(index)
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
- final int index = (int)$eval2;
- if (index >= $eval1.size() || index < 0) {
+ final int index = (int) $eval2;
+ if (index >= $eval1.numElements() || index < 0) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index);
+ ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")};
}
"""
})
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 15b33da884dcb..961b1d8616801 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -315,7 +315,6 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
case class Least(children: Seq[Expression]) extends Expression {
- require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length)
override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
@@ -323,7 +322,9 @@ case class Least(children: Seq[Expression]) extends Expression {
private lazy val ordering = TypeUtils.getOrdering(dataType)
override def checkInputDataTypes(): TypeCheckResult = {
- if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+ if (children.length <= 1) {
+ TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
+ } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got LEAST (${children.map(_.dataType)}).")
@@ -369,7 +370,6 @@ case class Least(children: Seq[Expression]) extends Expression {
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
case class Greatest(children: Seq[Expression]) extends Expression {
- require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length)
override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
@@ -377,7 +377,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
private lazy val ordering = TypeUtils.getOrdering(dataType)
override def checkInputDataTypes(): TypeCheckResult = {
- if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+ if (children.length <= 1) {
+ TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
+ } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got GREATEST (${children.map(_.dataType)}).")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
index 9e55f0546e123..6e7613340c032 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import java.sql.Date
import java.text.SimpleDateFormat
import java.util.{Calendar, TimeZone}
@@ -26,7 +25,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+import scala.util.Try
/**
* Returns the current date at the start of query evaluation.
@@ -62,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
}
}
+/**
+ * Adds a number of days to startdate.
+ */
+case class DateAdd(startDate: Expression, days: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def left: Expression = startDate
+ override def right: Expression = days
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+ override def dataType: DataType = DateType
+
+ override def nullSafeEval(start: Any, d: Any): Any = {
+ start.asInstanceOf[Int] + d.asInstanceOf[Int]
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (sd, d) => {
+ s"""${ev.primitive} = $sd + $d;"""
+ })
+ }
+}
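Since DateType values are stored as the number of days since 1970-01-01, the evaluation above is plain integer arithmetic. A minimal sketch of the same computation (java.time is used only for illustration; it is not part of this patch):

    // days since epoch for 2015-07-27
    val start: Int = java.time.LocalDate.of(2015, 7, 27).toEpochDay.toInt
    // the same addition DateAdd.nullSafeEval performs
    val result: Int = start + 10
    // java.time.LocalDate.ofEpochDay(result) == 2015-08-06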
+
+/**
+ * Subtracts a number of days from startdate.
+ */
+case class DateSub(startDate: Expression, days: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+ override def left: Expression = startDate
+ override def right: Expression = days
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+ override def dataType: DataType = DateType
+
+ override def nullSafeEval(start: Any, d: Any): Any = {
+ start.asInstanceOf[Int] - d.asInstanceOf[Int]
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (sd, d) => {
+ s"""${ev.primitive} = $sd - $d;"""
+ })
+ }
+}
+
case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -74,9 +122,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (c) =>
- s"""$dtu.getHours($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)")
}
}
@@ -92,9 +138,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (c) =>
- s"""$dtu.getMinutes($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)")
}
}
@@ -110,9 +154,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (c) =>
- s"""$dtu.getSeconds($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)")
}
}
@@ -128,9 +170,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (c) =>
- s"""$dtu.getDayInYear($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)")
}
}
@@ -147,9 +187,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, c =>
- s"""$dtu.getYear($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)")
}
}
@@ -165,9 +203,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (c) =>
- s"""$dtu.getQuarter($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)")
}
}
@@ -183,9 +219,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (c) =>
- s"""$dtu.getMonth($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)")
}
}
@@ -201,9 +235,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (c) =>
- s"""$dtu.getDayOfMonth($c)"""
- )
+ defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)")
}
}
@@ -226,7 +258,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (time) => {
+ nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val c = ctx.freshName("cal")
ctx.addMutableState(cal, c,
@@ -250,18 +282,503 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
- override def prettyName: String = "date_format"
-
override protected def nullSafeEval(timestamp: Any, format: Any): Any = {
val sdf = new SimpleDateFormat(format.toString)
- UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000)))
+ UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000)))
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val sdf = classOf[SimpleDateFormat].getName
defineCodeGen(ctx, ev, (timestamp, format) => {
s"""UTF8String.fromString((new $sdf($format.toString()))
- .format(new java.sql.Date($timestamp / 1000)))"""
+ .format(new java.util.Date($timestamp / 1000)))"""
+ })
+ }
+
+ override def prettyName: String = "date_format"
+}
+
+/**
+ * Converts a time string with the given pattern
+ * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html])
+ * to a Unix timestamp (in seconds); returns null on failure.
+ * Note that the Hive Language Manual says it returns 0 on failure, but in fact it returns null.
+ * If the second parameter is missing, "yyyy-MM-dd HH:mm:ss" is used.
+ * If no parameters are provided, the first parameter defaults to current_timestamp.
+ * If the first parameter is a Date or Timestamp instead of a String, the
+ * second parameter is ignored.
+ */
+case class UnixTimestamp(timeExp: Expression, format: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = timeExp
+ override def right: Expression = format
+
+ def this(time: Expression) = {
+ this(time, Literal("yyyy-MM-dd HH:mm:ss"))
+ }
+
+ def this() = {
+ this(CurrentTimestamp())
+ }
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(StringType, DateType, TimestampType), StringType)
+
+ override def dataType: DataType = LongType
+
+ private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
+
+ override def eval(input: InternalRow): Any = {
+ val t = left.eval(input)
+ if (t == null) {
+ null
+ } else {
+ left.dataType match {
+ case DateType =>
+ DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L
+ case TimestampType =>
+ t.asInstanceOf[Long] / 1000000L
+ case StringType if right.foldable =>
+ if (constFormat != null) {
+ Try(new SimpleDateFormat(constFormat.toString).parse(
+ t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
+ } else {
+ null
+ }
+ case StringType =>
+ val f = format.eval(input)
+ if (f == null) {
+ null
+ } else {
+ val formatString = f.asInstanceOf[UTF8String].toString
+ Try(new SimpleDateFormat(formatString).parse(
+ t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
+ }
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ left.dataType match {
+ case StringType if right.foldable =>
+ val sdf = classOf[SimpleDateFormat].getName
+ val fString = if (constFormat == null) null else constFormat.toString
+ val formatter = ctx.freshName("formatter")
+ if (fString == null) {
+ s"""
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ """
+ } else {
+ val eval1 = left.gen(ctx)
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = ${eval1.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ try {
+ $sdf $formatter = new $sdf("$fString");
+ ${ev.primitive} =
+ $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L;
+ } catch (java.lang.Throwable e) {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
+ case StringType =>
+ val sdf = classOf[SimpleDateFormat].getName
+ nullSafeCodeGen(ctx, ev, (string, format) => {
+ s"""
+ try {
+ ${ev.primitive} =
+ (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L;
+ } catch (java.lang.Throwable e) {
+ ${ev.isNull} = true;
+ }
+ """
+ })
+ case TimestampType =>
+ val eval1 = left.gen(ctx)
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = ${eval1.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.primitive} = ${eval1.primitive} / 1000000L;
+ }
+ """
+ case DateType =>
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ val eval1 = left.gen(ctx)
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = ${eval1.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L;
+ }
+ """
+ }
+ }
+}
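The StringType path above amounts to parsing with SimpleDateFormat and dividing the millisecond result by 1000, with any parse error mapped to null. A self-contained sketch of that behaviour (illustrative only):

    import java.text.SimpleDateFormat
    import scala.util.Try

    def toUnixSeconds(time: String, format: String = "yyyy-MM-dd HH:mm:ss"): Option[Long] =
      Try(new SimpleDateFormat(format).parse(time).getTime / 1000L).toOption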
+
+/**
+ * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+ * representing the timestamp of that moment in the current system time zone in the given
+ * format. If the format is missing, the default "yyyy-MM-dd HH:mm:ss" is used, producing
+ * output like "1970-01-01 00:00:00".
+ * Note that the Hive Language Manual says it returns 0 on failure, but in fact it returns null.
+ */
+case class FromUnixTime(sec: Expression, format: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def left: Expression = sec
+ override def right: Expression = format
+
+ def this(unix: Expression) = {
+ this(unix, Literal("yyyy-MM-dd HH:mm:ss"))
+ }
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType)
+
+ private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
+
+ override def eval(input: InternalRow): Any = {
+ val time = left.eval(input)
+ if (time == null) {
+ null
+ } else {
+ if (format.foldable) {
+ if (constFormat == null) {
+ null
+ } else {
+ Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format(
+ new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null)
+ }
+ } else {
+ val f = format.eval(input)
+ if (f == null) {
+ null
+ } else {
+ Try(UTF8String.fromString(new SimpleDateFormat(
+ f.asInstanceOf[UTF8String].toString).format(new java.util.Date(
+ time.asInstanceOf[Long] * 1000L)))).getOrElse(null)
+ }
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val sdf = classOf[SimpleDateFormat].getName
+ if (format.foldable) {
+ if (constFormat == null) {
+ s"""
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ """
+ } else {
+ val t = left.gen(ctx)
+ s"""
+ ${t.code}
+ boolean ${ev.isNull} = ${t.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ try {
+ ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format(
+ new java.util.Date(${t.primitive} * 1000L)));
+ } catch (java.lang.Throwable e) {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
+ } else {
+ nullSafeCodeGen(ctx, ev, (seconds, f) => {
+ s"""
+ try {
+ ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format(
+ new java.util.Date($seconds * 1000L)));
+ } catch (java.lang.Throwable e) {
+ ${ev.isNull} = true;
+ }""".stripMargin
+ })
+ }
+ }
+}
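FromUnixTime is the inverse: multiply the seconds by 1000 and format the resulting java.util.Date, again mapping an invalid format to null. A sketch (illustrative only):

    import java.text.SimpleDateFormat
    import scala.util.Try

    def fromUnixSeconds(sec: Long, format: String = "yyyy-MM-dd HH:mm:ss"): Option[String] =
      Try(new SimpleDateFormat(format).format(new java.util.Date(sec * 1000L))).toOption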
+
+/**
+ * Returns the last day of the month which the date belongs to.
+ */
+case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+ override def child: Expression = startDate
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
+
+ override def dataType: DataType = DateType
+
+ override def nullSafeEval(date: Any): Any = {
+ DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int])
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)")
+ }
+
+ override def prettyName: String = "last_day"
+}
+
+/**
+ * Returns the first date that is later than startDate and falls on the given dayOfWeek.
+ * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first
+ * Sunday later than 2015-07-27.
+ *
+ * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]].
+ */
+case class NextDay(startDate: Expression, dayOfWeek: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def left: Expression = startDate
+ override def right: Expression = dayOfWeek
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+
+ override def dataType: DataType = DateType
+
+ override def nullSafeEval(start: Any, dayOfW: Any): Any = {
+ val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String])
+ if (dow == -1) {
+ null
+ } else {
+ val sd = start.asInstanceOf[Int]
+ DateTimeUtils.getNextDateForDayOfWeek(sd, dow)
+ }
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (sd, dowS) => {
+ val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$")
+ val dayOfWeekTerm = ctx.freshName("dayOfWeek")
+ if (dayOfWeek.foldable) {
+ val input = dayOfWeek.eval().asInstanceOf[UTF8String]
+ if ((input eq null) || DateTimeUtils.getDayOfWeekFromString(input) == -1) {
+ s"""
+ |${ev.isNull} = true;
+ """.stripMargin
+ } else {
+ val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input)
+ s"""
+ |${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue);
+ """.stripMargin
+ }
+ } else {
+ s"""
+ |int $dayOfWeekTerm = $dateTimeUtilClass.getDayOfWeekFromString($dowS);
+ |if ($dayOfWeekTerm == -1) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm);
+ |}
+ """.stripMargin
+ }
})
}
+
+ override def prettyName: String = "next_day"
+}
+
+/**
+ * Adds an interval to timestamp.
+ */
+case class TimeAdd(start: Expression, interval: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def left: Expression = start
+ override def right: Expression = interval
+
+ override def toString: String = s"$left + $right"
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
+
+ override def dataType: DataType = TimestampType
+
+ override def nullSafeEval(start: Any, interval: Any): Any = {
+ val itvl = interval.asInstanceOf[CalendarInterval]
+ DateTimeUtils.timestampAddInterval(
+ start.asInstanceOf[Long], itvl.months, itvl.microseconds)
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ defineCodeGen(ctx, ev, (sd, i) => {
+ s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)"""
+ })
+ }
+}
+
+/**
+ * Subtracts an interval from timestamp.
+ */
+case class TimeSub(start: Expression, interval: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def left: Expression = start
+ override def right: Expression = interval
+
+ override def toString: String = s"$left - $right"
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
+
+ override def dataType: DataType = TimestampType
+
+ override def nullSafeEval(start: Any, interval: Any): Any = {
+ val itvl = interval.asInstanceOf[CalendarInterval]
+ DateTimeUtils.timestampAddInterval(
+ start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ defineCodeGen(ctx, ev, (sd, i) => {
+ s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
+ })
+ }
+}
+
+/**
+ * Returns the date that is num_months after start_date.
+ */
+case class AddMonths(startDate: Expression, numMonths: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def left: Expression = startDate
+ override def right: Expression = numMonths
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+ override def dataType: DataType = DateType
+
+ override def nullSafeEval(start: Any, months: Any): Any = {
+ DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int])
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ defineCodeGen(ctx, ev, (sd, m) => {
+ s"""$dtu.dateAddMonths($sd, $m)"""
+ })
+ }
+}
+
+/**
+ * Returns the number of months between dates date1 and date2.
+ */
+case class MonthsBetween(date1: Expression, date2: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def left: Expression = date1
+ override def right: Expression = date2
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)
+
+ override def dataType: DataType = DoubleType
+
+ override def nullSafeEval(t1: Any, t2: Any): Any = {
+ DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long])
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ defineCodeGen(ctx, ev, (l, r) => {
+ s"""$dtu.monthsBetween($l, $r)"""
+ })
+ }
+}
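The rule implemented by DateTimeUtils.monthsBetween (added later in this patch) returns a whole number when both dates fall on the same day of month or both are month ends, and otherwise spreads the day difference over a 31-day month. Two worked examples under that rule, ignoring time-of-day:

    // same day of month on both sides: a whole number of months
    months_between('2015-03-30', '2015-01-30') = 2.0
    // otherwise: (months1 - months2) + (day1 - day2) / 31.0 = 1 + (16 - 31) / 31.0
    months_between('2015-02-16', '2015-01-31') ≈ 0.51612903   (rounded to 8 digits)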
+
+/**
+ * Returns the date part of a timestamp or string.
+ */
+case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ // Spark's implicit casting will accept strings in both date and timestamp formats, as
+ // well as TimestampType.
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
+
+ override def dataType: DataType = DateType
+
+ override def eval(input: InternalRow): Any = child.eval(input)
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, d => d)
+ }
+}
+
+/**
+ * Returns date truncated to the unit specified by the format.
+ */
+case class TruncDate(date: Expression, format: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+ override def left: Expression = date
+ override def right: Expression = format
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+ override def dataType: DataType = DateType
+ override def prettyName: String = "trunc"
+
+ lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
+
+ override def eval(input: InternalRow): Any = {
+ val minItem = if (format.foldable) {
+ minItemConst
+ } else {
+ DateTimeUtils.parseTruncLevel(format.eval(input).asInstanceOf[UTF8String])
+ }
+ if (minItem == -1) {
+ // unknown format
+ null
+ } else {
+ val d = date.eval(input)
+ if (d == null) {
+ null
+ } else {
+ DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem)
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+
+ if (format.foldable) {
+ if (minItemConst == -1) {
+ s"""
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ """
+ } else {
+ val d = date.gen(ctx)
+ s"""
+ ${d.code}
+ boolean ${ev.isNull} = ${d.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst);
+ }
+ """
+ }
+ } else {
+ nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
+ val form = ctx.freshName("form")
+ s"""
+ int $form = $dtu.parseTruncLevel($fmt);
+ if ($form == -1) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} = $dtu.truncDate($dateVal, $form);
+ }
+ """
+ })
+ }
+ }
}
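For reference, the truncation levels accepted by DateTimeUtils.parseTruncLevel (added elsewhere in this patch) cover year and month granularity, following Hive's convention; anything else yields -1 and hence a null result. Intended semantics, shown informally:

    trunc('2015-07-27', 'MM')   -> 2015-07-01
    trunc('2015-07-27', 'YEAR') -> 2015-01-01
    trunc('2015-07-27', 'DD')   -> null   // unsupported level, parseTruncLevel returns -1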
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 2dbcf2830f876..8064235c64ef9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
child.dataType match {
case ArrayType(_, _) =>
- val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
- if (inputArray == null) Nil else inputArray.map(v => InternalRow(v))
+ val inputArray = child.eval(input).asInstanceOf[ArrayData]
+ if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v))
case MapType(_, _, _) =>
val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
if (inputMap == null) Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 85060b7893556..34bad23802ba4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -42,7 +42,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
- case i: Interval => Literal(i, IntervalType)
+ case i: CalendarInterval => Literal(i, CalendarIntervalType)
case null => Literal(null, NullType)
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
@@ -118,7 +118,7 @@ case class Literal protected (value: Any, dataType: DataType)
super.genCode(ctx, ev)
} else {
ev.isNull = "false"
- ev.primitive = s"${value}"
+ ev.primitive = s"${value}D"
""
}
case ByteType | ShortType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 68cca0ad3d067..e6d807f6d897b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -646,19 +646,19 @@ case class Logarithm(left: Expression, right: Expression)
/**
* Round the `child`'s result to `scale` decimal place when `scale` >= 0
* or round at integral part when `scale` < 0.
- * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
+ * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30.
*
- * Child of IntegralType would eval to itself when `scale` >= 0.
- * Child of FractionalType whose value is NaN or Infinite would always eval to itself.
+ * A child of IntegralType rounds to itself when `scale` >= 0.
+ * A child of FractionalType whose value is NaN or Infinite always rounds to itself.
*
- * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]],
- * which leads to scale update in DecimalType's [[PrecisionInfo]]
+ * Round's dataType always equals `child`'s dataType, except for DecimalType,
+ * where the scale may decrease relative to the original DecimalType.
*
* @param child expr to be round, all [[NumericType]] is allowed as Input
* @param scale new scale to be round to, this should be a constant int at runtime
*/
case class Round(child: Expression, scale: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
import BigDecimal.RoundingMode.HALF_UP
@@ -838,6 +838,4 @@ case class Round(child: Expression, scale: Expression)
"""
}
}
-
- override def prettyName: String = "round"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 5bfe1cad24a3e..ab7d3afce8f2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -31,7 +31,7 @@ object InterpretedPredicate {
def create(expression: Expression): (InternalRow => Boolean) = {
expression.foreach {
- case n: Nondeterministic => n.initialize()
+ case n: Nondeterministic => n.setInitialValues()
case _ =>
}
(r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index 8f30519697a37..62d3d204ca872 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -66,7 +66,7 @@ case class Rand(seed: Long) extends RDG {
val rngTerm = ctx.freshName("rng")
val className = classOf[XORShiftRandom].getName
ctx.addMutableState(className, rngTerm,
- s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());")
+ s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
ev.isNull = "false"
s"""
final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble();
@@ -89,7 +89,7 @@ case class Randn(seed: Long) extends RDG {
val rngTerm = ctx.freshName("rng")
val className = classOf[XORShiftRandom].getName
ctx.addMutableState(className, rngTerm,
- s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());")
+ s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
ev.isNull = "false"
s"""
final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index b7c4ece4a16fe..df6ea586c87ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{DataType, StructType, AtomicType}
+import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -39,6 +39,7 @@ abstract class MutableRow extends InternalRow {
def setShort(i: Int, value: Short): Unit = { update(i, value) }
def setByte(i: Int, value: Byte): Unit = { update(i, value) }
def setFloat(i: Int, value: Float): Unit = { update(i, value) }
+ def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
def setString(i: Int, value: String): Unit = {
update(i, UTF8String.fromString(value))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 38b0fb37dee3b..79c0ca56a8e79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -22,7 +22,6 @@ import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -52,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval =>
- s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+ s"${eval.isNull} ? null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
boolean ${ev.isNull} = false;
@@ -93,7 +92,7 @@ case class ConcatWs(children: Seq[Expression])
val flatInputs = children.flatMap { child =>
child.eval(input) match {
case s: UTF8String => Iterator(s)
- case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
+ case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String])
case null => Iterator(null.asInstanceOf[UTF8String])
}
}
@@ -106,7 +105,7 @@ case class ConcatWs(children: Seq[Expression])
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval =>
- s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+ s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
@@ -666,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def nullSafeEval(string: Any, regex: Any): Any = {
- string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq
+ val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1)
+ new GenericArrayData(strings.asInstanceOf[Array[Any]])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, pattern) =>
- s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer(
- java.util.Arrays.asList($str.split($pattern, -1)));""")
+ // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
+ s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""")
}
override def prettyName: String = "split"
@@ -777,7 +778,6 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def dataType: DataType = IntegerType
-
protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any =
leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String])
@@ -1009,7 +1009,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
s"""
${evalSubject.code}
- boolean ${ev.isNull} = ${evalSubject.isNull};
+ boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${evalSubject.isNull}) {
${evalRegexp.code}
@@ -1104,9 +1104,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
val evalIdx = idx.gen(ctx)
s"""
- ${ctx.javaType(dataType)} ${ev.primitive} = null;
- boolean ${ev.isNull} = true;
${evalSubject.code}
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ boolean ${ev.isNull} = true;
if (!${evalSubject.isNull}) {
${evalRegexp.code}
if (!${evalRegexp.isNull}) {
@@ -1118,7 +1118,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
${classOf[java.util.regex.Matcher].getCanonicalName} m =
- ${termPattern}.matcher(${evalSubject.primitive}.toString());
+ ${termPattern}.matcher(${evalSubject.primitive}.toString());
if (m.find()) {
${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));
@@ -1140,7 +1140,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
* fractional part.
*/
case class FormatNumber(x: Expression, d: Expression)
- extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+ extends BinaryExpression with ExpectsInputTypes {
override def left: Expression = x
override def right: Expression = d
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 813c62009666c..29d706dcb39a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -312,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
- case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+ case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
+ Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index af68358daf5f1..a67f8de6b733a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -33,7 +34,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
}.nonEmpty
)
- !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
+ expressions.forall(_.resolved) && childrenResolved && !hasSpecialExpressions
}
}
@@ -67,7 +68,7 @@ case class Generate(
generator.resolved &&
childrenResolved &&
generator.elementTypes.length == generatorOutput.length &&
- !generatorOutput.exists(!_.resolved)
+ generatorOutput.forall(_.resolved)
}
// we don't want the gOutput to be taken as part of the expressions
@@ -187,7 +188,7 @@ case class WithWindowDefinition(
}
/**
- * @param order The ordering expressions
+ * @param order The ordering expressions, which should all be [[AttributeReference]]s
* @param global True means global sorting apply for entire data set,
* False means sorting only apply within the partition.
* @param child Child logical plan
@@ -197,6 +198,11 @@ case class Sort(
global: Boolean,
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
+
+ def hasNoEvaluation: Boolean = order.forall(_.child.isInstanceOf[AttributeReference])
+
+ override lazy val resolved: Boolean =
+ expressions.forall(_.resolved) && childrenResolved && hasNoEvaluation
}
case class Aggregate(
@@ -211,9 +217,11 @@ case class Aggregate(
}.nonEmpty
)
- !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
+ expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions
}
+ lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this)
+
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 2dcfa19fec383..f4d1dbaf28efe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -86,14 +86,6 @@ sealed trait Partitioning {
*/
def satisfies(required: Distribution): Boolean
- /**
- * Returns true iff all distribution guarantees made by this partitioning can also be made
- * for the `other` specified partitioning.
- * For example, two [[HashPartitioning HashPartitioning]]s are
- * only compatible if the `numPartitions` of them is the same.
- */
- def compatibleWith(other: Partitioning): Boolean
-
/** Returns the expressions that are used to key the partitioning. */
def keyExpressions: Seq[Expression]
}
@@ -104,11 +96,6 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case _ => false
}
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case UnknownPartitioning(_) => true
- case _ => false
- }
-
override def keyExpressions: Seq[Expression] = Nil
}
@@ -117,11 +104,6 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case SinglePartition => true
- case _ => false
- }
-
override def keyExpressions: Seq[Expression] = Nil
}
@@ -130,11 +112,6 @@ case object BroadcastPartitioning extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case SinglePartition => true
- case _ => false
- }
-
override def keyExpressions: Seq[Expression] = Nil
}
@@ -159,12 +136,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case BroadcastPartitioning => true
- case h: HashPartitioning if h == this => true
- case _ => false
- }
-
override def keyExpressions: Seq[Expression] = expressions
}
@@ -199,11 +170,5 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case BroadcastPartitioning => true
- case r: RangePartitioning if r == this => true
- case _ => false
- }
-
override def keyExpressions: Seq[Expression] = ordering.map(_.child)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 07412e73b6a5b..5a7c25b8d508d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -45,6 +45,7 @@ object DateTimeUtils {
final val to2001 = -11323
// this is year -17999, calculation: 50 * daysIn400Year
+ final val YearZero = -17999
final val toYearZero = to2001 + 7304850
@transient lazy val defaultTimeZone = TimeZone.getDefault
@@ -573,4 +574,243 @@ object DateTimeUtils {
dayInYear - 334
}
}
+
+ /**
+ * The number of days for each month (not leap year)
+ */
+ private val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
+
+ /**
+ * Returns the date value for the first day of the given month.
+ * The month is expressed in months since year zero (17999 BC), starting from 0.
+ */
+ private def firstDayOfMonth(absoluteMonth: Int): Int = {
+ val absoluteYear = absoluteMonth / 12
+ var monthInYear = absoluteMonth - absoluteYear * 12
+ var date = getDateFromYear(absoluteYear)
+ if (monthInYear >= 2 && isLeapYear(absoluteYear + YearZero)) {
+ date += 1
+ }
+ while (monthInYear > 0) {
+ date += monthDays(monthInYear - 1)
+ monthInYear -= 1
+ }
+ date
+ }
+
+ /**
+ * Returns the date value for January 1 of the given year.
+ * The year is expressed in years since year zero (17999 BC), starting from 0.
+ */
+ private def getDateFromYear(absoluteYear: Int): Int = {
+ val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100
+ + absoluteYear / 4)
+ absoluteDays - toYearZero
+ }
+
+ /**
+ * Adds a year-month interval (i.e. a number of months) to a date.
+ * Returns a date value, expressed in days since 1.1.1970.
+ */
+ def dateAddMonths(days: Int, months: Int): Int = {
+ val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
+ val currentMonthInYear = absoluteMonth % 12
+ val currentYear = absoluteMonth / 12
+ val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
+ val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
+
+ val dayOfMonth = getDayOfMonth(days)
+ val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) {
+ // last day of the month
+ lastDayOfMonth
+ } else {
+ dayOfMonth
+ }
+ firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
+ }
+
+ /**
+ * Adds a full interval (months and microseconds) to a timestamp.
+ * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
+ */
+ def timestampAddInterval(start: Long, months: Int, microseconds: Long): Long = {
+ val days = millisToDays(start / 1000L)
+ val newDays = dateAddMonths(days, months)
+ daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds
+ }
+
+ /**
+ * Returns the day-of-month value of the last day of the month that the given date belongs
+ * to. The date is expressed in days since 1.1.1970. The return value starts from 1.
+ */
+ private def getLastDayInMonthOfMonth(date: Int): Int = {
+ var (year, dayInYear) = getYearAndDayInYear(date)
+ if (isLeapYear(year)) {
+ if (dayInYear > 31 && dayInYear <= 60) {
+ return 29
+ } else if (dayInYear > 60) {
+ dayInYear = dayInYear - 1
+ }
+ }
+ if (dayInYear <= 31) {
+ 31
+ } else if (dayInYear <= 59) {
+ 28
+ } else if (dayInYear <= 90) {
+ 31
+ } else if (dayInYear <= 120) {
+ 30
+ } else if (dayInYear <= 151) {
+ 31
+ } else if (dayInYear <= 181) {
+ 30
+ } else if (dayInYear <= 212) {
+ 31
+ } else if (dayInYear <= 243) {
+ 31
+ } else if (dayInYear <= 273) {
+ 30
+ } else if (dayInYear <= 304) {
+ 31
+ } else if (dayInYear <= 334) {
+ 30
+ } else {
+ 31
+ }
+ }
+
+ /**
+ * Returns the number of months between time1 and time2. time1 and time2 are expressed in
+ * microseconds since 1.1.1970.
+ *
+ * If time1 and time2 are on the same day of month, or both are the last day of their month,
+ * an integer is returned (the time-of-day portion is ignored).
+ *
+ * Otherwise, the difference is calculated based on 31 days per month, and rounded to
+ * 8 digits.
+ */
+ def monthsBetween(time1: Long, time2: Long): Double = {
+ val millis1 = time1 / 1000L
+ val millis2 = time2 / 1000L
+ val date1 = millisToDays(millis1)
+ val date2 = millisToDays(millis2)
+ // TODO(davies): get year, month, dayOfMonth from single function
+ val dayInMonth1 = getDayOfMonth(date1)
+ val dayInMonth2 = getDayOfMonth(date2)
+ val months1 = getYear(date1) * 12 + getMonth(date1)
+ val months2 = getYear(date2) * 12 + getMonth(date2)
+
+ if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1)
+ && dayInMonth2 == getLastDayInMonthOfMonth(date2))) {
+ return (months1 - months2).toDouble
+ }
+ // millisecond precision is enough for 8 digits of precision to the right of the decimal point
+ val timeInDay1 = millis1 - daysToMillis(date1)
+ val timeInDay2 = millis2 - daysToMillis(date2)
+ val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY
+ val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0
+ // rounding to 8 digits
+ math.round(diff * 1e8) / 1e8
+ }
+
+ /*
+ * Returns the day of week parsed from a String, numbered starting from Thursday as 0
+ * (because 1970-01-01 is a Thursday).
+ */
+ def getDayOfWeekFromString(string: UTF8String): Int = {
+ val dowString = string.toString.toUpperCase
+ dowString match {
+ case "SU" | "SUN" | "SUNDAY" => 3
+ case "MO" | "MON" | "MONDAY" => 4
+ case "TU" | "TUE" | "TUESDAY" => 5
+ case "WE" | "WED" | "WEDNESDAY" => 6
+ case "TH" | "THU" | "THURSDAY" => 0
+ case "FR" | "FRI" | "FRIDAY" => 1
+ case "SA" | "SAT" | "SATURDAY" => 2
+ case _ => -1
+ }
+ }
+
+ /**
+ * Returns the first date which is later than startDate and is of the given dayOfWeek.
+ * dayOfWeek is an integer in the range [0, 6], where 0 is Thu, 1 is Fri, etc.
+ */
+ def getNextDateForDayOfWeek(startDate: Int, dayOfWeek: Int): Int = {
+ startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7
+ }
+
+ /**
+ * Returns the last day of the month for the given date. The date is expressed in days
+ * since 1.1.1970.
+ */
+ def getLastDayOfMonth(date: Int): Int = {
+ var (year, dayInYear) = getYearAndDayInYear(date)
+ if (isLeapYear(year)) {
+ if (dayInYear > 31 && dayInYear <= 60) {
+ return date + (60 - dayInYear)
+ } else if (dayInYear > 60) {
+ dayInYear = dayInYear - 1
+ }
+ }
+ val lastDayOfMonthInYear = if (dayInYear <= 31) {
+ 31
+ } else if (dayInYear <= 59) {
+ 59
+ } else if (dayInYear <= 90) {
+ 90
+ } else if (dayInYear <= 120) {
+ 120
+ } else if (dayInYear <= 151) {
+ 151
+ } else if (dayInYear <= 181) {
+ 181
+ } else if (dayInYear <= 212) {
+ 212
+ } else if (dayInYear <= 243) {
+ 243
+ } else if (dayInYear <= 273) {
+ 273
+ } else if (dayInYear <= 304) {
+ 304
+ } else if (dayInYear <= 334) {
+ 334
+ } else {
+ 365
+ }
+ date + (lastDayOfMonthInYear - dayInYear)
+ }
+
+ private val TRUNC_TO_YEAR = 1
+ private val TRUNC_TO_MONTH = 2
+ private val TRUNC_INVALID = -1
+
+ /**
+ * Returns the truncated date for the given date and trunc level.
+ * The trunc level should be generated using `parseTruncLevel()` and can only be 1 or 2.
+ */
+ def truncDate(d: Int, level: Int): Int = {
+ if (level == TRUNC_TO_YEAR) {
+ d - DateTimeUtils.getDayInYear(d) + 1
+ } else if (level == TRUNC_TO_MONTH) {
+ d - DateTimeUtils.getDayOfMonth(d) + 1
+ } else {
+ throw new Exception(s"Invalid trunc level: $level")
+ }
+ }
+
+ /**
+ * Returns the truncate level, which can be TRUNC_TO_YEAR, TRUNC_TO_MONTH, or TRUNC_INVALID.
+ * TRUNC_INVALID means the truncate level is unsupported.
+ */
+ def parseTruncLevel(format: UTF8String): Int = {
+ if (format == null) {
+ TRUNC_INVALID
+ } else {
+ format.toString.toUpperCase match {
+ case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
+ case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
+ case _ => TRUNC_INVALID
+ }
+ }
+ }
}
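
For illustration, a minimal sketch (not part of the patch) of how the helpers added above are expected to behave. It uses only entry points that appear in this change, plus fromJavaDate/toJavaDate, which are assumed to already exist in DateTimeUtils as used elsewhere in this diff; the object name below is hypothetical.

import java.sql.Date
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

object DateTimeUtilsSketch {
  def main(args: Array[String]): Unit = {
    // 2015-07-23 is a Thursday; dates are counted in days since 1970-01-01.
    val d = DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-23"))

    // Day-of-week codes start at Thursday = 0 because 1970-01-01 is a Thursday.
    val monday = DateTimeUtils.getDayOfWeekFromString(UTF8String.fromString("MON")) // 4
    // The first Monday strictly after 2015-07-23 is 2015-07-27.
    println(DateTimeUtils.toJavaDate(DateTimeUtils.getNextDateForDayOfWeek(d, monday)))

    // Truncation keeps only the year (or month) component of the date.
    val yearLevel = DateTimeUtils.parseTruncLevel(UTF8String.fromString("YEAR"))
    println(DateTimeUtils.toJavaDate(DateTimeUtils.truncDate(d, yearLevel))) // 2015-01-01

    // Month arithmetic clamps to the last day of the target month: Jan 30 + 1 month = Feb 28.
    val jan30 = DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-30"))
    println(DateTimeUtils.toJavaDate(DateTimeUtils.dateAddMonths(jan30, 1))) // 2015-02-28
  }
}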
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 40bf4b299c990..e0667c629486d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -95,7 +95,7 @@ private[sql] object TypeCollection {
* Types that include numeric types and interval type. They are only used in unary_minus,
* unary_positive, add and subtract operations.
*/
- val NumericAndInterval = TypeCollection(NumericType, IntervalType)
+ val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
new file mode 100644
index 0000000000000..14a7285877622
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+
+abstract class ArrayData extends SpecializedGetters with Serializable {
+ // todo: remove this after we handle all types. (map type needs a special getter)
+ def get(ordinal: Int): Any
+
+ def numElements(): Int
+
+ // todo: need a more efficient way to iterate over array types.
+ def toArray(): Array[Any] = {
+ val n = numElements()
+ val values = new Array[Any](n)
+ var i = 0
+ while (i < n) {
+ if (isNullAt(i)) {
+ values(i) = null
+ } else {
+ values(i) = get(i)
+ }
+ i += 1
+ }
+ values
+ }
+
+ override def toString(): String = toArray.mkString("[", ",", "]")
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[ArrayData]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[ArrayData]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numElements()
+ if (len != other.numElements()) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = get(i)
+ val o2 = other.get(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ override def hashCode: Int = {
+ var result: Int = 37
+ var i = 0
+ val len = numElements()
+ while (i < len) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ get(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a)
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
similarity index 64%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
index 87c6e9e6e5e2c..3565f52c21f69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
@@ -22,16 +22,19 @@ import org.apache.spark.annotation.DeveloperApi
/**
* :: DeveloperApi ::
- * The data type representing time intervals.
+ * The data type representing calendar time intervals. The calendar time interval is stored
+ * internally in two components: the number of months and the number of microseconds.
*
- * Please use the singleton [[DataTypes.IntervalType]].
+ * Note that calendar intervals are not comparable.
+ *
+ * Please use the singleton [[DataTypes.CalendarIntervalType]].
*/
@DeveloperApi
-class IntervalType private() extends DataType {
+class CalendarIntervalType private() extends DataType {
- override def defaultSize: Int = 4096
+ override def defaultSize: Int = 16
- private[spark] override def asNullable: IntervalType = this
+ private[spark] override def asNullable: CalendarIntervalType = this
}
-case object IntervalType extends IntervalType
+case object CalendarIntervalType extends CalendarIntervalType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 591fb26e67c4a..f4428c2e8b202 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -142,12 +142,21 @@ object DataType {
("type", JString("struct"))) =>
StructType(fields.map(parseStructField))
+ // Scala/Java UDT
case JSortedObject(
("class", JString(udtClass)),
("pyClass", _),
("sqlType", _),
("type", JString("udt"))) =>
Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
+
+ // Python UDT
+ case JSortedObject(
+ ("pyClass", JString(pyClass)),
+ ("serializedClass", JString(serialized)),
+ ("sqlType", v: JValue),
+ ("type", JString("udt"))) =>
+ new PythonUserDefinedType(parseDataType(v), pyClass, serialized)
}
private def parseStructField(json: JValue): StructField = json match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index bc689810bc292..c0155eeb450a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
* @return true if successful, false if overflow would occur
*/
def changePrecision(precision: Int, scale: Int): Boolean = {
+ // fast path for UnsafeProjection
+ if (precision == this.precision && scale == this.scale) {
+ return true
+ }
// First, update our longVal if we can, or transfer over to using a BigDecimal
if (decimalVal.eq(null)) {
if (scale < _scale) {
@@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
decimalVal = newVal
} else {
// We're still using Longs, but we should check whether we match the new precision
- val p = POW_10(math.min(_precision, MAX_LONG_DIGITS))
+ val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
if (longVal <= -p || longVal >= p) {
// Note that we shouldn't have been able to fix this by switching to BigDecimal
return false
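
For illustration, a minimal sketch (not part of the patch) of what the changePrecision fix above changes: with the corrected bound, a value that needs more digits than the target precision is reported as overflow, which is why cast(123L, DecimalType(3, 1)) in the CastSuite changes later in this diff now evaluates to null. It assumes Decimal's Long constructor gives precision 20 and scale 0, as in the current Decimal code.

import org.apache.spark.sql.types.Decimal

// Decimal(123L) starts as a Long-backed decimal with precision 20 and scale 0 (assumed).
val d = Decimal(123L)
// A Decimal(3, 1) can hold at most 99.9; scaling 123 to scale 1 needs 4 digits.
val fits = d.changePrecision(3, 1)
println(fits) // false -> Cast to DecimalType(3, 1) returns null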
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
new file mode 100644
index 0000000000000..35ace673fb3da
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval}
+
+class GenericArrayData(array: Array[Any]) extends ArrayData {
+ private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T]
+
+ override def toArray(): Array[Any] = array
+
+ override def get(ordinal: Int): Any = array(ordinal)
+
+ override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
+
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+
+ override def numElements(): Int = array.length
+}
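
For illustration, a minimal sketch (not part of the patch) of how GenericArrayData wraps a plain Array[Any] and inherits the element-wise equality defined in ArrayData above, which treats NaN values as equal and compares byte arrays by content.

import org.apache.spark.sql.types.GenericArrayData

val a = new GenericArrayData(Array[Any](1, null, Double.NaN, Array[Byte](1, 2)))
val b = new GenericArrayData(Array[Any](1, null, Double.NaN, Array[Byte](1, 2)))

println(a.numElements()) // 4
println(a.isNullAt(1))   // true
println(a == b)          // true: nulls, NaN, and byte arrays are compared structurally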
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index e47cfb4833bd8..4305903616bd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/** Paired Python UDT class, if exists. */
def pyUDT: String = null
+ /** Serialized Python UDT class, if exists. */
+ def serializedPyClass: String = null
+
/**
* Convert the user type to a SQL datum
*
@@ -82,3 +85,29 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
override private[sql] def acceptsType(dataType: DataType) =
this.getClass == dataType.getClass
}
+
+/**
+ * ::DeveloperApi::
+ * The user defined type in Python.
+ *
+ * Note: This can only be accessed via Python UDFs, or accessed as a serialized object.
+ */
+private[sql] class PythonUserDefinedType(
+ val sqlType: DataType,
+ override val pyUDT: String,
+ override val serializedPyClass: String) extends UserDefinedType[Any] {
+
+ /* The serialization is handled by the UDT class in Python */
+ override def serialize(obj: Any): Any = obj
+ override def deserialize(datum: Any): Any = datum
+
+ /* There is no Java class for Python UDT */
+ override def userClass: java.lang.Class[Any] = null
+
+ override private[sql] def jsonValue: JValue = {
+ ("type" -> "udt") ~
+ ("pyClass" -> pyUDT) ~
+ ("serializedClass" -> serializedPyClass) ~
+ ("sqlType" -> sqlType.jsonValue)
+ }
+}
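
For illustration, a minimal sketch (not part of the patch) of the JSON shape produced by PythonUserDefinedType and consumed by the new Python UDT case in DataType.scala earlier in this diff. It assumes the snippet is compiled inside the org.apache.spark.sql package (the class is private[sql]); the pyspark class name and serialized payload are hypothetical placeholders.

import org.apache.spark.sql.types._

val udt = new PythonUserDefinedType(
  sqlType = DoubleType,
  pyUDT = "pyspark.example.ExampleUDT",        // hypothetical Python class name
  serializedPyClass = "<pickled-class-bytes>") // hypothetical serialized payload

// jsonValue emits the keys: type, pyClass, serializedClass, sqlType
println(udt.json)
// {"type":"udt","pyClass":"pyspark.example.ExampleUDT","serializedClass":"<pickled-class-bytes>","sqlType":"double"}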
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ed645b618dc9b..a86cefe941e8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -145,15 +145,15 @@ class AnalysisSuite extends AnalysisTest {
'e / 'e as 'div5))
val pl = plan.asInstanceOf[Project].projectList
- // StringType will be promoted into Double
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- assert(pl(3).dataType == DoubleType)
+ // StringType will be promoted into Decimal(38, 18)
+ assert(pl(3).dataType == DecimalType(38, 29))
assert(pl(4).dataType == DoubleType)
}
- test("pull out nondeterministic expressions from unary LogicalPlan") {
+ test("pull out nondeterministic expressions from RepartitionByExpression") {
val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
val projected = Alias(Rand(33), "_nondeterministic")()
val expected =
@@ -162,4 +162,42 @@ class AnalysisSuite extends AnalysisTest {
Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}
+
+ test("pull out nondeterministic expressions from Sort") {
+ val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation)
+ val analyzed = caseSensitiveAnalyzer.execute(plan)
+ analyzed.transform {
+ case s: Sort if s.expressions.exists(!_.deterministic) =>
+ fail("nondeterministic expressions are not allowed in Sort")
+ }
+ }
+
+ test("remove still-need-evaluate ordering expressions from sort") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+
+ def makeOrder(e: Expression): SortOrder = SortOrder(e, Ascending)
+
+ val noEvalOrdering = makeOrder(a)
+ val noEvalOrderingWithAlias = makeOrder(Alias(Alias(b, "name1")(), "name2")())
+
+ val needEvalExpr = Coalesce(Seq(a, Literal("1")))
+ val needEvalExpr2 = Coalesce(Seq(a, b))
+ val needEvalOrdering = makeOrder(needEvalExpr)
+ val needEvalOrdering2 = makeOrder(needEvalExpr2)
+
+ val plan = Sort(
+ Seq(noEvalOrdering, noEvalOrderingWithAlias, needEvalOrdering, needEvalOrdering2),
+ false, testRelation2)
+
+ val evaluatedOrdering = makeOrder(AttributeReference("_sortCondition", StringType)())
+ val materializedExprs = Seq(needEvalExpr, needEvalExpr2).map(e => Alias(e, "_sortCondition")())
+
+ val expected =
+ Project(testRelation2.output,
+ Sort(Seq(makeOrder(a), makeOrder(b), evaluatedOrdering, evaluatedOrdering), false,
+ Project(testRelation2.output ++ materializedExprs, testRelation2)))
+
+ checkAnalysis(plan, expected)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ad15136ee9a2f..a52e4cb4dfd9f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}
test("check types for unary arithmetic") {
- assertError(UnaryMinus('stringField), "type (numeric or interval)")
+ assertError(UnaryMinus('stringField), "type (numeric or calendarinterval)")
assertError(Abs('stringField), "expected to be of type numeric")
assertError(BitwiseNot('stringField), "expected to be of type integral")
}
@@ -78,8 +78,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
- assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type")
- assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type")
+ assertError(Add('booleanField, 'booleanField), "accepts (numeric or calendarinterval) type")
+ assertError(Subtract('booleanField, 'booleanField),
+ "accepts (numeric or calendarinterval) type")
assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
@@ -166,10 +167,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
assertError(
CreateNamedStruct(Seq(1, "a", "b", 2.0)),
- "Odd position only allow foldable and not-null StringType expressions")
+ "Only foldable StringType expressions are allowed to appear at odd position")
assertError(
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
- "Odd position only allow foldable and not-null StringType expressions")
+ "Only foldable StringType expressions are allowed to appear at odd position")
+ assertError(
+ CreateNamedStruct(Seq(Literal.create(null, StringType), "a")),
+ "Field name should not be null")
}
test("check types for ROUND") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 4454d51b75877..70608771dd110 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.analysis
+import java.sql.Timestamp
+
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
class HiveTypeCoercionSuite extends PlanTest {
@@ -116,7 +119,7 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldNotCast(IntegerType, MapType)
shouldNotCast(IntegerType, StructType)
- shouldNotCast(IntervalType, StringType)
+ shouldNotCast(CalendarIntervalType, StringType)
// Don't implicitly cast complex types to string.
shouldNotCast(ArrayType(StringType), StringType)
@@ -400,6 +403,33 @@ class HiveTypeCoercionSuite extends PlanTest {
}
}
+ test("rule for date/timestamp operations") {
+ val dateTimeOperations = HiveTypeCoercion.DateTimeOperations
+ val date = Literal(new java.sql.Date(0L))
+ val timestamp = Literal(new Timestamp(0L))
+ val interval = Literal(new CalendarInterval(0, 0))
+ val str = Literal("2015-01-01")
+
+ ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType))
+ ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType))
+ ruleTest(dateTimeOperations, Add(timestamp, interval),
+ Cast(TimeAdd(timestamp, interval), TimestampType))
+ ruleTest(dateTimeOperations, Add(interval, timestamp),
+ Cast(TimeAdd(timestamp, interval), TimestampType))
+ ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType))
+ ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType))
+
+ ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType))
+ ruleTest(dateTimeOperations, Subtract(timestamp, interval),
+ Cast(TimeSub(timestamp, interval), TimestampType))
+ ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType))
+
+ // interval operations should not be affected
+ ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval))
+ ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval))
+ }
+
+
/**
* There are rules that need to not fire before child expressions get resolved.
* We use this test to make sure those rules do not fire early.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index e7e5231d32c9e..d03b0fbbfb2b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -116,9 +116,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
test("Abs") {
testNumericDataTypes { convert =>
+ val input = Literal(convert(1))
+ val dataType = input.dataType
checkEvaluation(Abs(Literal(convert(0))), convert(0))
checkEvaluation(Abs(Literal(convert(1))), convert(1))
checkEvaluation(Abs(Literal(convert(-1))), convert(1))
+ checkEvaluation(Abs(Literal.create(null, dataType)), null)
}
}
@@ -170,6 +173,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Pmod(-7, 3), 2)
checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
- checkEvaluation(Pmod(2L, Long.MaxValue), 2)
+ checkEvaluation(Pmod(2L, Long.MaxValue), 2L)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
index 648fbf5a4c30b..fa30fbe528479 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
@@ -30,8 +30,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, expected)
}
- check(1.toByte, ~1.toByte)
- check(1000.toShort, ~1000.toShort)
+ // Need the extra toByte even though IntelliJ thinks it is not needed.
+ check(1.toByte, (~1.toByte).toByte)
+ check(1000.toShort, (~1000.toShort).toShort)
check(1000000, ~1000000)
check(123456789123L, ~123456789123L)
@@ -45,8 +46,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, expected)
}
- check(1.toByte, 2.toByte, 1.toByte & 2.toByte)
- check(1000.toShort, 2.toShort, 1000.toShort & 2.toShort)
+ // Need the extra toByte even though IntelliJ thinks it is not needed.
+ check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte)
+ check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort)
check(1000000, 4, 1000000 & 4)
check(123456789123L, 5L, 123456789123L & 5L)
@@ -63,8 +65,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, expected)
}
- check(1.toByte, 2.toByte, 1.toByte | 2.toByte)
- check(1000.toShort, 2.toShort, 1000.toShort | 2.toShort)
+ // Need the extra toByte even though IntelliJ thinks it is not needed.
+ check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte)
+ check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort)
check(1000000, 4, 1000000 | 4)
check(123456789123L, 5L, 123456789123L | 5L)
@@ -81,8 +84,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, expected)
}
- check(1.toByte, 2.toByte, 1.toByte ^ 2.toByte)
- check(1000.toShort, 2.toShort, 1000.toShort ^ 2.toShort)
+ // Need the extra toByte even though IntelliJ thinks it is not needed.
+ check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte)
+ check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort)
check(1000000, 4, 1000000 ^ 4)
check(123456789123L, 5L, 123456789123L ^ 5L)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 408353cf70a49..1ad70733eae03 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Timestamp, Date}
import java.util.{TimeZone, Calendar}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -43,6 +44,42 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(v, Literal(expected).dataType), expected)
}
+ private def checkNullCast(from: DataType, to: DataType): Unit = {
+ checkEvaluation(Cast(Literal.create(null, from), to), null)
+ }
+
+ test("null cast") {
+ import DataTypeTestUtils._
+
+ // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic
+ // to ensure we test every possible cast situation here
+ atomicTypes.zip(atomicTypes).foreach { case (from, to) =>
+ checkNullCast(from, to)
+ }
+
+ atomicTypes.foreach(dt => checkNullCast(NullType, dt))
+ atomicTypes.foreach(dt => checkNullCast(dt, StringType))
+ checkNullCast(StringType, BinaryType)
+ checkNullCast(StringType, BooleanType)
+ checkNullCast(DateType, BooleanType)
+ checkNullCast(TimestampType, BooleanType)
+ numericTypes.foreach(dt => checkNullCast(dt, BooleanType))
+
+ checkNullCast(StringType, TimestampType)
+ checkNullCast(BooleanType, TimestampType)
+ checkNullCast(DateType, TimestampType)
+ numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
+
+ atomicTypes.foreach(dt => checkNullCast(dt, DateType))
+
+ checkNullCast(StringType, CalendarIntervalType)
+ numericTypes.foreach(dt => checkNullCast(StringType, dt))
+ numericTypes.foreach(dt => checkNullCast(BooleanType, dt))
+ numericTypes.foreach(dt => checkNullCast(DateType, dt))
+ numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
+ for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to)
+ }
+
test("cast string to date") {
var c = Calendar.getInstance()
c.set(2015, 0, 1, 0, 0, 0)
@@ -69,8 +106,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast string to timestamp") {
- checkEvaluation(Cast(Literal("123"), TimestampType),
- null)
+ checkEvaluation(Cast(Literal("123"), TimestampType), null)
var c = Calendar.getInstance()
c.set(2015, 0, 1, 0, 0, 0)
@@ -206,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
- checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))
+ checkEvaluation(cast(123L, DecimalType(3, 1)), null)
- // TODO: Fix the following bug and re-enable it.
- // checkEvaluation(cast(123L, DecimalType(2, 0)), null)
+ checkEvaluation(cast(123L, DecimalType(2, 0)), null)
}
test("cast from boolean") {
@@ -473,6 +508,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val array_notNull = Literal.create(Seq("123", "abc", ""),
ArrayType(StringType, containsNull = false))
+ checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
+
{
val ret = cast(array, ArrayType(IntegerType, containsNull = true))
assert(ret.resolved === true)
@@ -526,6 +563,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
Map("a" -> "123", "b" -> "abc", "c" -> ""),
MapType(StringType, StringType, valueContainsNull = false))
+ checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
+
{
val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
assert(ret.resolved === true)
@@ -580,6 +619,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast from struct") {
+ checkNullCast(
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", IntegerType))),
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", StringType))))
+
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
@@ -683,13 +730,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
- InternalRow(
- Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")),
- Map(
- UTF8String.fromString("a") -> UTF8String.fromString("123"),
- UTF8String.fromString("b") -> UTF8String.fromString("abc"),
- UTF8String.fromString("c") -> UTF8String.fromString("")),
- InternalRow(0)),
+ Row(
+ Seq("123", "abc", ""),
+ Map("a" ->"123", "b" -> "abc", "c" -> ""),
+ Row(0)),
StructType(Seq(
StructField("a",
ArrayType(StringType, containsNull = false), nullable = true),
@@ -709,23 +753,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("l", LongType, nullable = true)))))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(
+ checkEvaluation(ret, Row(
Seq(123, null, null),
- Map(
- UTF8String.fromString("a") -> true,
- UTF8String.fromString("b") -> true,
- UTF8String.fromString("c") -> false),
- InternalRow(0L)))
+ Map("a" -> true, "b" -> true, "c" -> false),
+ Row(0L)))
}
test("case between string and interval") {
- import org.apache.spark.unsafe.types.Interval
+ import org.apache.spark.unsafe.types.CalendarInterval
- checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
- new Interval(-3, 7 * Interval.MICROS_PER_HOUR))
+ checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
+ new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR))
checkEvaluation(Cast(Literal.create(
- new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
+ new CalendarInterval(15, -3 * CalendarInterval.MICROS_PER_DAY), CalendarIntervalType),
+ StringType),
"interval 1 years 3 months -3 days")
}
-
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index fc842772f3480..3fa246b69d1f1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -110,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
expr.dataType match {
case ArrayType(StructType(fields), containsNull) =>
val field = fields.find(_.name == fieldName).get
- GetArrayStructFields(expr, field, fields.indexOf(field), containsNull)
+ GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull)
}
}
@@ -132,6 +132,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
+ checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
}
test("CreateStruct") {
@@ -139,26 +140,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val c1 = 'a.int.at(0)
val c3 = 'c.int.at(2)
checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
+ checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null))
}
test("CreateNamedStruct") {
- val row = InternalRow(1, 2, 3)
+ val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
val c3 = 'c.int.at(2)
- checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row)
- }
-
- test("CreateNamedStruct with literal field") {
- val row = InternalRow(1, 2, 3)
- val c1 = 'a.int.at(0)
+ checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row)
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")),
- InternalRow(1, UTF8String.fromString("y")), row)
- }
-
- test("CreateNamedStruct from all literal fields") {
- checkEvaluation(
- CreateNamedStruct(Seq("a", "x", "b", 2.0)),
- InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty)
+ create_row(1, UTF8String.fromString("y")), row)
+ checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)),
+ create_row(UTF8String.fromString("x"), 2.0))
+ checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))),
+ create_row(null))
}
test("test dsl for complex type") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index b31d6661c8c1c..d26bcdb2902ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -149,6 +149,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row)
checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row)
+ val nullLiteral = Literal.create(null, IntegerType)
+ checkEvaluation(Least(Seq(nullLiteral, nullLiteral)), null)
checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty)
checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty)
checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty)
@@ -188,6 +190,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row)
checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row)
+ val nullLiteral = Literal.create(null, IntegerType)
+ checkEvaluation(Greatest(Seq(nullLiteral, nullLiteral)), null)
checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty)
checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty)
checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index bdba6ce891386..6c15c05da3094 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.catalyst.expressions
-import java.sql.{Timestamp, Date}
+import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Calendar
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.{StringType, TimestampType, DateType}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.sql.types._
class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -31,58 +33,23 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime)
+ test("datetime function current_date") {
+ val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+ val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int]
+ val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+ assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1)
+ }
+
+ test("datetime function current_timestamp") {
+ val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long])
+ val t1 = System.currentTimeMillis()
+ assert(math.abs(t1 - ct.getTime) < 5000)
+ }
+
test("DayOfYear") {
val sdfDay = new SimpleDateFormat("D")
- (2002 to 2004).foreach { y =>
- (0 to 11).foreach { m =>
- (0 to 5).foreach { i =>
- val c = Calendar.getInstance()
- c.set(y, m, 28, 0, 0, 0)
- c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
- sdfDay.format(c.getTime).toInt)
- }
- }
- }
-
(1998 to 2002).foreach { y =>
- (0 to 11).foreach { m =>
- (0 to 5).foreach { i =>
- val c = Calendar.getInstance()
- c.set(y, m, 28, 0, 0, 0)
- c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
- sdfDay.format(c.getTime).toInt)
- }
- }
- }
-
- (1969 to 1970).foreach { y =>
- (0 to 11).foreach { m =>
- (0 to 5).foreach { i =>
- val c = Calendar.getInstance()
- c.set(y, m, 28, 0, 0, 0)
- c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
- sdfDay.format(c.getTime).toInt)
- }
- }
- }
-
- (2402 to 2404).foreach { y =>
- (0 to 11).foreach { m =>
- (0 to 5).foreach { i =>
- val c = Calendar.getInstance()
- c.set(y, m, 28, 0, 0, 0)
- c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
- sdfDay.format(c.getTime).toInt)
- }
- }
- }
-
- (2398 to 2402).foreach { y =>
- (0 to 11).foreach { m =>
+ (0 to 3).foreach { m =>
(0 to 5).foreach { i =>
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
@@ -92,6 +59,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
+ checkEvaluation(DayOfYear(Literal.create(null, DateType)), null)
}
test("Year") {
@@ -101,7 +69,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013)
val c = Calendar.getInstance()
- (2000 to 2010).foreach { y =>
+ (2000 to 2002).foreach { y =>
(0 to 11 by 11).foreach { m =>
c.set(y, m, 28)
(0 to 5 * 24).foreach { i =>
@@ -139,20 +107,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Month(Cast(Literal(ts), DateType)), 11)
(2003 to 2004).foreach { y =>
- (0 to 11).foreach { m =>
- (0 to 5 * 24).foreach { i =>
- val c = Calendar.getInstance()
- c.set(y, m, 28, 0, 0, 0)
- c.add(Calendar.HOUR_OF_DAY, i)
- checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))),
- c.get(Calendar.MONTH) + 1)
- }
- }
- }
-
- (1999 to 2000).foreach { y =>
- (0 to 11).foreach { m =>
- (0 to 5 * 24).foreach { i =>
+ (0 to 3).foreach { m =>
+ (0 to 2 * 24).foreach { i =>
val c = Calendar.getInstance()
c.set(y, m, 28, 0, 0, 0)
c.add(Calendar.HOUR_OF_DAY, i)
@@ -246,4 +202,235 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("date_add") {
+ checkEvaluation(
+ DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(1)),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
+ checkEvaluation(
+ DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(-365)),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28")))
+ checkEvaluation(DateAdd(Literal.create(null, DateType), Literal(1)), null)
+ checkEvaluation(DateAdd(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)),
+ null)
+ checkEvaluation(DateAdd(Literal.create(null, DateType), Literal.create(null, IntegerType)),
+ null)
+ }
+
+ test("date_sub") {
+ checkEvaluation(
+ DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1)),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2014-12-31")))
+ checkEvaluation(
+ DateSub(Literal(Date.valueOf("2015-01-01")), Literal(-1)),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02")))
+ checkEvaluation(DateSub(Literal.create(null, DateType), Literal(1)), null)
+ checkEvaluation(DateSub(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)),
+ null)
+ checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)),
+ null)
+ }
+
+ test("time_add") {
+ checkEvaluation(
+ TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")),
+ Literal(new CalendarInterval(1, 123000L))),
+ DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00.123")))
+
+ checkEvaluation(
+ TimeAdd(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))),
+ null)
+ checkEvaluation(
+ TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")),
+ Literal.create(null, CalendarIntervalType)),
+ null)
+ checkEvaluation(
+ TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)),
+ null)
+ }
+
+ test("time_sub") {
+ checkEvaluation(
+ TimeSub(Literal(Timestamp.valueOf("2016-03-31 10:00:00")),
+ Literal(new CalendarInterval(1, 0))),
+ DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00")))
+ checkEvaluation(
+ TimeSub(
+ Literal(Timestamp.valueOf("2016-03-30 00:00:01")),
+ Literal(new CalendarInterval(1, 2000000.toLong))),
+ DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-28 23:59:59")))
+
+ checkEvaluation(
+ TimeSub(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))),
+ null)
+ checkEvaluation(
+ TimeSub(Literal(Timestamp.valueOf("2016-01-29 10:00:00")),
+ Literal.create(null, CalendarIntervalType)),
+ null)
+ checkEvaluation(
+ TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)),
+ null)
+ }
+
+ test("add_months") {
+ checkEvaluation(AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(1)),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28")))
+ checkEvaluation(AddMonths(Literal(Date.valueOf("2016-03-30")), Literal(-1)),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
+ checkEvaluation(
+ AddMonths(Literal(Date.valueOf("2015-01-30")), Literal.create(null, IntegerType)),
+ null)
+ checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null)
+ checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)),
+ null)
+ }
+
+ test("months_between") {
+ checkEvaluation(
+ MonthsBetween(Literal(Timestamp.valueOf("1997-02-28 10:30:00")),
+ Literal(Timestamp.valueOf("1996-10-30 00:00:00"))),
+ 3.94959677)
+ checkEvaluation(
+ MonthsBetween(Literal(Timestamp.valueOf("2015-01-30 11:52:00")),
+ Literal(Timestamp.valueOf("2015-01-30 11:50:00"))),
+ 0.0)
+ checkEvaluation(
+ MonthsBetween(Literal(Timestamp.valueOf("2015-01-31 00:00:00")),
+ Literal(Timestamp.valueOf("2015-03-31 22:00:00"))),
+ -2.0)
+ checkEvaluation(
+ MonthsBetween(Literal(Timestamp.valueOf("2015-03-31 22:00:00")),
+ Literal(Timestamp.valueOf("2015-02-28 00:00:00"))),
+ 1.0)
+ val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00"))
+ val tnull = Literal.create(null, TimestampType)
+ checkEvaluation(MonthsBetween(t, tnull), null)
+ checkEvaluation(MonthsBetween(tnull, t), null)
+ checkEvaluation(MonthsBetween(tnull, tnull), null)
+ }
+
+ test("last_day") {
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-02-28"))), Date.valueOf("2015-02-28"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-03-27"))), Date.valueOf("2015-03-31"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-04-26"))), Date.valueOf("2015-04-30"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-05-25"))), Date.valueOf("2015-05-31"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-06-24"))), Date.valueOf("2015-06-30"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-07-23"))), Date.valueOf("2015-07-31"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-08-01"))), Date.valueOf("2015-08-31"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-09-02"))), Date.valueOf("2015-09-30"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-10-03"))), Date.valueOf("2015-10-31"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-11-04"))), Date.valueOf("2015-11-30"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2015-12-05"))), Date.valueOf("2015-12-31"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31"))
+ checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29"))
+ checkEvaluation(LastDay(Literal.create(null, DateType)), null)
+ }
+
+ test("next_day") {
+ def testNextDay(input: String, dayOfWeek: String, output: String): Unit = {
+ checkEvaluation(
+ NextDay(Literal(Date.valueOf(input)), NonFoldableLiteral(dayOfWeek)),
+ DateTimeUtils.fromJavaDate(Date.valueOf(output)))
+ checkEvaluation(
+ NextDay(Literal(Date.valueOf(input)), Literal(dayOfWeek)),
+ DateTimeUtils.fromJavaDate(Date.valueOf(output)))
+ }
+ testNextDay("2015-07-23", "Mon", "2015-07-27")
+ testNextDay("2015-07-23", "mo", "2015-07-27")
+ testNextDay("2015-07-23", "Tue", "2015-07-28")
+ testNextDay("2015-07-23", "tu", "2015-07-28")
+ testNextDay("2015-07-23", "we", "2015-07-29")
+ testNextDay("2015-07-23", "wed", "2015-07-29")
+ testNextDay("2015-07-23", "Thu", "2015-07-30")
+ testNextDay("2015-07-23", "TH", "2015-07-30")
+ testNextDay("2015-07-23", "Fri", "2015-07-24")
+ testNextDay("2015-07-23", "fr", "2015-07-24")
+
+ checkEvaluation(NextDay(Literal(Date.valueOf("2015-07-23")), Literal("xx")), null)
+ checkEvaluation(NextDay(Literal.create(null, DateType), Literal("xx")), null)
+ checkEvaluation(
+ NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
+ }
+
+ test("function to_date") {
+ checkEvaluation(
+ ToDate(Literal(Date.valueOf("2015-07-22"))),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22")))
+ checkEvaluation(ToDate(Literal.create(null, DateType)), null)
+ }
+
+ test("function trunc") {
+ def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
+ checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
+ expected)
+ checkEvaluation(
+ TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
+ expected)
+ }
+ val date = Date.valueOf("2015-07-22")
+ Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt =>
+ testTrunc(date, fmt, Date.valueOf("2015-01-01"))
+ }
+ Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
+ testTrunc(date, fmt, Date.valueOf("2015-07-01"))
+ }
+ testTrunc(date, "DD", null)
+ testTrunc(date, null, null)
+ testTrunc(null, "MON", null)
+ testTrunc(null, null, null)
+ }
+
+ test("from_unixtime") {
+ val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+ val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
+ val sdf2 = new SimpleDateFormat(fmt2)
+ checkEvaluation(
+ FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0)))
+ checkEvaluation(FromUnixTime(
+ Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(1000000)))
+ checkEvaluation(
+ FromUnixTime(Literal(-1000L), Literal(fmt2)), sdf2.format(new Timestamp(-1000000)))
+ checkEvaluation(
+ FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType)), null)
+ checkEvaluation(
+ FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss")), null)
+ checkEvaluation(FromUnixTime(Literal(1000L), Literal.create(null, StringType)), null)
+ checkEvaluation(
+ FromUnixTime(Literal(0L), Literal("not a valid format")), null)
+ }
+
+ test("unix_timestamp") {
+ val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+ val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
+ val sdf2 = new SimpleDateFormat(fmt2)
+ val fmt3 = "yy-MM-dd"
+ val sdf3 = new SimpleDateFormat(fmt3)
+ val date1 = Date.valueOf("2015-07-24")
+ checkEvaluation(
+ UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L)
+ checkEvaluation(UnixTimestamp(
+ Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L)
+ checkEvaluation(
+ UnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L)
+ checkEvaluation(
+ UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")),
+ DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L)
+ checkEvaluation(
+ UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L)
+ checkEvaluation(UnixTimestamp(
+ Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)),
+ DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L)
+ val t1 = UnixTimestamp(
+ CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
+ val t2 = UnixTimestamp(
+ CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
+ assert(t2 - t1 <= 1)
+ checkEvaluation(
+ UnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null)
+ checkEvaluation(
+ UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null)
+ checkEvaluation(UnixTimestamp(
+ Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L)
+ checkEvaluation(
+ UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala
deleted file mode 100644
index 1618c24871c60..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-
-class DatetimeFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
- test("datetime function current_date") {
- val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
- val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int]
- val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis())
- assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1)
- }
-
- test("datetime function current_timestamp") {
- val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long])
- val t1 = System.currentTimeMillis()
- assert(math.abs(t1 - ct.getTime) < 5000)
- }
-
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index ab0cdc857c80e..3c05e5c3b833c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -65,7 +65,7 @@ trait ExpressionEvalHelper {
protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
expression.foreach {
- case n: Nondeterministic => n.initialize()
+ case n: Nondeterministic => n.setInitialValues()
case _ =>
}
expression.eval(inputRow)
@@ -82,6 +82,7 @@ trait ExpressionEvalHelper {
s"""
|Code generation of $expression failed:
|$e
+ |${e.getStackTraceString}
""".stripMargin)
}
}
@@ -114,7 +115,7 @@ trait ExpressionEvalHelper {
val actual = plan(inputRow).get(0, expression.dataType)
if (!checkResult(actual, expected)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
- fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
}
}
@@ -146,7 +147,8 @@ trait ExpressionEvalHelper {
if (actual != expectedRow) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
- fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input")
+ fail("Incorrect Evaluation in codegen mode: " +
+ s"$expression, actual: $actual, expected: $expectedRow$input")
}
if (actual.copy() != expectedRow) {
fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow")
@@ -163,12 +165,21 @@ trait ExpressionEvalHelper {
expression)
val unsafeRow = plan(inputRow)
- // UnsafeRow cannot be compared with GenericInternalRow directly
- val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow)
- val expectedRow = InternalRow(expected)
- if (actual != expectedRow) {
- val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
- fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input")
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+
+ if (expected == null) {
+ if (!unsafeRow.isNullAt(0)) {
+ val expectedRow = InternalRow(expected)
+ fail("Incorrect evaluation in unsafe mode: " +
+ s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
+ }
+ } else {
+ val lit = InternalRow(expected)
+ val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit)
+ if (unsafeRow != expectedRow) {
+ fail("Incorrect evaluation in unsafe mode: " +
+ s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 21459a7c69838..9fcb548af6bbb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -110,35 +110,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
}
- test("conv") {
- checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
- checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
- checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
- checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
- checkEvaluation(
- Conv(Literal("1234"), Literal(10), Literal(37)), null)
- checkEvaluation(
- Conv(Literal(""), Literal(10), Literal(16)), null)
- checkEvaluation(
- Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
- // If there is an invalid digit in the number, the longest valid prefix should be converted.
- checkEvaluation(
- Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
- }
-
private def checkNaN(
- expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
+ expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
checkNaNWithoutCodegen(expression, inputRow)
checkNaNWithGeneratedProjection(expression, inputRow)
checkNaNWithOptimization(expression, inputRow)
}
private def checkNaNWithoutCodegen(
- expression: Expression,
- expected: Any,
- inputRow: InternalRow = EmptyRow): Unit = {
+ expression: Expression,
+ expected: Any,
+ inputRow: InternalRow = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
@@ -149,7 +131,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
-
private def checkNaNWithGeneratedProjection(
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {
@@ -172,6 +153,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
}
+ test("conv") {
+ checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
+ checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
+ checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
+ checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
+ checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
+ checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
+ checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null)
+ checkEvaluation(
+ Conv(Literal("1234"), Literal(10), Literal(37)), null)
+ checkEvaluation(
+ Conv(Literal(""), Literal(10), Literal(16)), null)
+ checkEvaluation(
+ Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
+ // If there is an invalid digit in the number, the longest valid prefix should be converted.
+ checkEvaluation(
+ Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
+ }
+
test("e") {
testLeaf(EulerNumber, math.E)
}
@@ -417,7 +417,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("round") {
- val domain = -6 to 6
+ val scales = -6 to 6
val doublePi: Double = math.Pi
val shortPi: Short = 31415
val intPi: Int = 314159265
@@ -437,17 +437,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
Seq.fill(7)(31415926535897932L)
- val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
- BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
- BigDecimal(3.141593), BigDecimal(3.1415927))
-
- domain.zipWithIndex.foreach { case (scale, i) =>
+ scales.zipWithIndex.foreach { case (scale, i) =>
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
}
+ val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
+ BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
+ BigDecimal(3.141593), BigDecimal(3.1415927))
// round_scale > current_scale would result in precision increase
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
(0 to 7).foreach { i =>
@@ -456,5 +455,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
(8 to 10).foreach { scale =>
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
}
+
+ DataTypeTestUtils.numericTypes.foreach { dataType =>
+ checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null)
+ checkEvaluation(Round(Literal.create(null, dataType),
+ Literal.create(null, IntegerType)), null)
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
new file mode 100644
index 0000000000000..31ecf4a9e810a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types._
+
+
+/**
+ * A literal value that is not foldable. Used in expression codegen testing to test code paths
+ * that behave differently based on foldable values.
+ */
+case class NonFoldableLiteral(value: Any, dataType: DataType)
+ extends LeafExpression with CodegenFallback {
+
+ override def foldable: Boolean = false
+ override def nullable: Boolean = true
+
+ override def toString: String = if (value != null) value.toString else "null"
+
+ override def eval(input: InternalRow): Any = value
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ Literal.create(value, dataType).genCode(ctx, ev)
+ }
+}
+
+
+object NonFoldableLiteral {
+ def apply(value: Any): NonFoldableLiteral = {
+ val lit = Literal(value)
+ NonFoldableLiteral(lit.value, lit.dataType)
+ }
+ def create(value: Any, dataType: DataType): NonFoldableLiteral = {
+ val lit = Literal.create(value, dataType)
+ NonFoldableLiteral(lit.value, lit.dataType)
+ }
+}
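A minimal, hedged sketch of how this helper is meant to be used in expression suites (Coalesce and the values are illustrative, not part of the patch): pairing a foldable Literal with a NonFoldableLiteral exercises both the constant-folded and the generated-code paths for the same input.

  // Illustrative: same value, once as a foldable literal and once forced down the
  // non-constant-folded code path via NonFoldableLiteral.
  checkEvaluation(Coalesce(Seq(Literal.create(null, StringType), Literal("x"))), "x")
  checkEvaluation(Coalesce(Seq(NonFoldableLiteral.create(null, StringType), Literal("x"))), "x")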
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
similarity index 76%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
index 99e11fd64b2b9..bf1c930c0bd0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
@@ -15,18 +15,20 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.expression
+package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions. ExpressionEvalHelper
-import org.apache.spark.sql.execution.expressions.{SparkPartitionID, MonotonicallyIncreasingID}
class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
test("MonotonicallyIncreasingID") {
- checkEvaluation(MonotonicallyIncreasingID(), 0)
+ checkEvaluation(MonotonicallyIncreasingID(), 0L)
}
test("SparkPartitionID") {
- checkEvaluation(SparkPartitionID, 0)
+ checkEvaluation(SparkPartitionID(), 0)
+ }
+
+ test("InputFileName") {
+ checkEvaluation(InputFileName(), "")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index 698c81ba24482..4a644d136f09c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -20,9 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.DoubleType
-
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -30,4 +27,9 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001)
checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001)
}
+
+ test("SPARK-9127 codegen with long seed") {
+ checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001)
+ checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 3d294fda5d103..07b952531ec2e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -348,6 +348,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 "))
checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 "))
// scalastyle:on
+ checkEvaluation(StringTrim(Literal.create(null, StringType)), null)
+ checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null)
+ checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null)
}
test("FORMAT") {
@@ -391,6 +394,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val s3 = 'c.string.at(2)
val s4 = 'd.int.at(3)
val row1 = create_row("aaads", "aa", "zz", 1)
+ val row2 = create_row(null, "aa", "zz", 0)
+ val row3 = create_row("aaads", null, "zz", 0)
+ val row4 = create_row(null, null, null, 0)
checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
@@ -402,6 +408,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
checkEvaluation(new StringLocate(s3, s1), 0, row1)
checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1)
+ checkEvaluation(new StringLocate(s2, s1), null, row2)
+ checkEvaluation(new StringLocate(s2, s1), null, row3)
+ checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4)
}
test("LPAD/RPAD") {
@@ -448,6 +457,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row1 = create_row("abccc")
checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
checkEvaluation(StringReverse(s), "cccba", row1)
+ checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
}
test("SPACE") {
@@ -466,6 +476,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row1 = create_row("100-200", "(\\d+)", "num")
val row2 = create_row("100-200", "(\\d+)", "###")
val row3 = create_row("100-200", "(-)", "###")
+ val row4 = create_row(null, "(\\d+)", "###")
+ val row5 = create_row("100-200", null, "###")
+ val row6 = create_row("100-200", "(-)", null)
val s = 's.string.at(0)
val p = 'p.string.at(1)
@@ -475,6 +488,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, "num-num", row1)
checkEvaluation(expr, "###-###", row2)
checkEvaluation(expr, "100###200", row3)
+ checkEvaluation(expr, null, row4)
+ checkEvaluation(expr, null, row5)
+ checkEvaluation(expr, null, row6)
}
test("RegexExtract") {
@@ -482,6 +498,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
val row3 = create_row("100-200", "(\\d+).*", 1)
val row4 = create_row("100-200", "([a-z])", 1)
+ val row5 = create_row(null, "([a-z])", 1)
+ val row6 = create_row("100-200", null, 1)
+ val row7 = create_row("100-200", "([a-z])", null)
val s = 's.string.at(0)
val p = 'p.string.at(1)
@@ -492,6 +511,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, "200", row2)
checkEvaluation(expr, "100", row3)
checkEvaluation(expr, "", row4) // will not match anything, empty string get
+ checkEvaluation(expr, null, row5)
+ checkEvaluation(expr, null, row6)
+ checkEvaluation(expr, null, row7)
val expr1 = new RegExpExtract(s, p)
checkEvaluation(expr1, "100", row1)
@@ -501,11 +523,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)
val row1 = create_row("aa2bb3cc", "[1-9]+")
+ val row2 = create_row(null, "[1-9]+")
+ val row3 = create_row("aa2bb3cc", null)
checkEvaluation(
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1)
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
+ checkEvaluation(StringSplit(s1, s2), null, row2)
+ checkEvaluation(StringSplit(s1, s2), null, row3)
}
test("length for string / binary") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 48b7dc57451a3..c6b4c729de2f9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -39,6 +39,7 @@ class UnsafeFixedWidthAggregationMapSuite
private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
private def emptyAggregationBuffer: InternalRow = InternalRow(0)
+ private val PAGE_SIZE_BYTES: Long = 1L << 26 // 64 megabytes
private var memoryManager: TaskMemoryManager = null
@@ -54,13 +55,13 @@ class UnsafeFixedWidthAggregationMapSuite
}
test("supported schemas") {
+ assert(supportsAggregationBufferSchema(
+ StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
+ assert(!supportsAggregationBufferSchema(
+ StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
- assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
-
assert(
!supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
- assert(
- !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
}
test("empty map") {
@@ -69,7 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite
aggBufferSchema,
groupKeySchema,
memoryManager,
- 1024, // initial capacity
+ 1024, // initial capacity
+ PAGE_SIZE_BYTES,
false // disable perf metrics
)
assert(!map.iterator().hasNext)
@@ -83,6 +85,7 @@ class UnsafeFixedWidthAggregationMapSuite
groupKeySchema,
memoryManager,
1024, // initial capacity
+ PAGE_SIZE_BYTES,
false // disable perf metrics
)
val groupKey = InternalRow(UTF8String.fromString("cats"))
@@ -109,6 +112,7 @@ class UnsafeFixedWidthAggregationMapSuite
groupKeySchema,
memoryManager,
128, // initial capacity
+ PAGE_SIZE_BYTES,
false // disable perf metrics
)
val rand = new Random(42)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 2834b54e8fb2e..a0e1701339ea7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -46,7 +46,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getLong(1) === 1)
assert(unsafeRow.getInt(2) === 2)
- // We can copy UnsafeRows as long as they don't reference ObjectPools
val unsafeRowCopy = unsafeRow.copy()
assert(unsafeRowCopy.getLong(0) === 0)
assert(unsafeRowCopy.getLong(1) === 1)
@@ -122,8 +121,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
FloatType,
DoubleType,
StringType,
- BinaryType
- // DecimalType.Default,
+ BinaryType,
+ DecimalType.USER_DEFAULT
// ArrayType(IntegerType)
)
val converter = UnsafeProjection.create(fieldTypes)
@@ -146,11 +145,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(createdFromNull.getShort(3) === 0)
assert(createdFromNull.getInt(4) === 0)
assert(createdFromNull.getLong(5) === 0)
- assert(java.lang.Float.isNaN(createdFromNull.getFloat(6)))
- assert(java.lang.Double.isNaN(createdFromNull.getDouble(7)))
+ assert(createdFromNull.getFloat(6) === 0.0f)
+ assert(createdFromNull.getDouble(7) === 0.0d)
assert(createdFromNull.getUTF8String(8) === null)
assert(createdFromNull.getBinary(9) === null)
- // assert(createdFromNull.get(10) === null)
+ assert(createdFromNull.getDecimal(10, 10, 0) === null)
// assert(createdFromNull.get(11) === null)
// If we have an UnsafeRow with columns that are initially non-null and we null out those
@@ -168,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r.setDouble(7, 700)
r.update(8, UTF8String.fromString("hello"))
r.update(9, "world".getBytes)
- // r.update(10, Decimal(10))
+ r.setDecimal(10, Decimal(10), 10)
// r.update(11, Array(11))
r
}
@@ -184,7 +183,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9))
- // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+ assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
+ rowWithNoNullColumns.getDecimal(10, 10, 0))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
for (i <- fieldTypes.indices) {
@@ -203,7 +203,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
setToNullAfterCreation.setDouble(7, 700)
// setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
// setToNullAfterCreation.update(9, "world".getBytes)
- // setToNullAfterCreation.update(10, Decimal(10))
+ setToNullAfterCreation.setDecimal(10, Decimal(10), 10)
// setToNullAfterCreation.update(11, Array(11))
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
@@ -216,7 +216,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
// assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
// assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
- // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+ assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
+ rowWithNoNullColumns.getDecimal(10, 10, 0))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
index 478702fea6146..46daa3eb8bf80 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
@@ -73,4 +73,34 @@ class CodeFormatterSuite extends SparkFunSuite {
|}
""".stripMargin
}
+
+ testCase("if else on the same line") {
+ """
+ |class A {
+ | if (c) {duh;} else {boo;}
+ |}
+ """.stripMargin
+ }{
+ """
+ |class A {
+ | if (c) {duh;} else {boo;}
+ |}
+ """.stripMargin
+ }
+
+ testCase("function calls") {
+ """
+ |foo(
+ |a,
+ |b,
+ |c)
+ """.stripMargin
+ }{
+ """
+ |foo(
+ | a,
+ | b,
+ | c)
+ """.stripMargin
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
new file mode 100644
index 0000000000000..2d3f98dbbd3d1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{BooleanType, DataType}
+
+/**
+ * A test suite that makes sure code generation handles expressions' internal states correctly.
+ */
+class CodegenExpressionCachingSuite extends SparkFunSuite {
+
+ test("GenerateUnsafeProjection should initialize expressions") {
+ // Use an And to wrap two of them together in case we only initialize the top-level expressions.
+ val expr = And(NondeterministicExpression(), NondeterministicExpression())
+ val instance = UnsafeProjection.create(Seq(expr))
+ assert(instance.apply(null).getBoolean(0) === false)
+ }
+
+ test("GenerateProjection should initialize expressions") {
+ val expr = And(NondeterministicExpression(), NondeterministicExpression())
+ val instance = GenerateProjection.generate(Seq(expr))
+ assert(instance.apply(null).getBoolean(0) === false)
+ }
+
+ test("GenerateMutableProjection should initialize expressions") {
+ val expr = And(NondeterministicExpression(), NondeterministicExpression())
+ val instance = GenerateMutableProjection.generate(Seq(expr))()
+ assert(instance.apply(null).getBoolean(0) === false)
+ }
+
+ test("GeneratePredicate should initialize expressions") {
+ val expr = And(NondeterministicExpression(), NondeterministicExpression())
+ val instance = GeneratePredicate.generate(expr)
+ assert(instance.apply(null) === false)
+ }
+
+ test("GenerateUnsafeProjection should not share expression instances") {
+ val expr1 = MutableExpression()
+ val instance1 = UnsafeProjection.create(Seq(expr1))
+ assert(instance1.apply(null).getBoolean(0) === false)
+
+ val expr2 = MutableExpression()
+ expr2.mutableState = true
+ val instance2 = UnsafeProjection.create(Seq(expr2))
+ assert(instance1.apply(null).getBoolean(0) === false)
+ assert(instance2.apply(null).getBoolean(0) === true)
+ }
+
+ test("GenerateProjection should not share expression instances") {
+ val expr1 = MutableExpression()
+ val instance1 = GenerateProjection.generate(Seq(expr1))
+ assert(instance1.apply(null).getBoolean(0) === false)
+
+ val expr2 = MutableExpression()
+ expr2.mutableState = true
+ val instance2 = GenerateProjection.generate(Seq(expr2))
+ assert(instance1.apply(null).getBoolean(0) === false)
+ assert(instance2.apply(null).getBoolean(0) === true)
+ }
+
+ test("GenerateMutableProjection should not share expression instances") {
+ val expr1 = MutableExpression()
+ val instance1 = GenerateMutableProjection.generate(Seq(expr1))()
+ assert(instance1.apply(null).getBoolean(0) === false)
+
+ val expr2 = MutableExpression()
+ expr2.mutableState = true
+ val instance2 = GenerateMutableProjection.generate(Seq(expr2))()
+ assert(instance1.apply(null).getBoolean(0) === false)
+ assert(instance2.apply(null).getBoolean(0) === true)
+ }
+
+ test("GeneratePredicate should not share expression instances") {
+ val expr1 = MutableExpression()
+ val instance1 = GeneratePredicate.generate(expr1)
+ assert(instance1.apply(null) === false)
+
+ val expr2 = MutableExpression()
+ expr2.mutableState = true
+ val instance2 = GeneratePredicate.generate(expr2)
+ assert(instance1.apply(null) === false)
+ assert(instance2.apply(null) === true)
+ }
+
+}
+
+/**
+ * An expression that's non-deterministic and doesn't support codegen.
+ */
+case class NondeterministicExpression()
+ extends LeafExpression with Nondeterministic with CodegenFallback {
+ override protected def initInternal(): Unit = { }
+ override protected def evalInternal(input: InternalRow): Any = false
+ override def nullable: Boolean = false
+ override def dataType: DataType = BooleanType
+}
+
+
+/**
+ * An expression with mutable state so we can change it freely in our test suite.
+ */
+case class MutableExpression() extends LeafExpression with CodegenFallback {
+ var mutableState: Boolean = false
+ override def eval(input: InternalRow): Any = mutableState
+
+ override def nullable: Boolean = false
+ override def dataType: DataType = BooleanType
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index fab9eb9cd4c9f..60d2bcfe13757 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -19,47 +19,48 @@ package org.apache.spark.sql.catalyst.util
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
-import java.util.{TimeZone, Calendar}
+import java.util.{Calendar, TimeZone}
import org.apache.spark.SparkFunSuite
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
class DateTimeUtilsSuite extends SparkFunSuite {
private[this] def getInUTCDays(timestamp: Long): Int = {
val tz = TimeZone.getDefault
- ((timestamp + tz.getOffset(timestamp)) / DateTimeUtils.MILLIS_PER_DAY).toInt
+ ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt
}
test("timestamp and us") {
val now = new Timestamp(System.currentTimeMillis())
now.setNanos(1000)
- val ns = DateTimeUtils.fromJavaTimestamp(now)
+ val ns = fromJavaTimestamp(now)
assert(ns % 1000000L === 1)
- assert(DateTimeUtils.toJavaTimestamp(ns) === now)
+ assert(toJavaTimestamp(ns) === now)
List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t =>
- val ts = DateTimeUtils.toJavaTimestamp(t)
- assert(DateTimeUtils.fromJavaTimestamp(ts) === t)
- assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts)
+ val ts = toJavaTimestamp(t)
+ assert(fromJavaTimestamp(ts) === t)
+ assert(toJavaTimestamp(fromJavaTimestamp(ts)) === ts)
}
}
test("us and julian day") {
- val (d, ns) = DateTimeUtils.toJulianDay(0)
- assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH)
- assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND)
- assert(DateTimeUtils.fromJulianDay(d, ns) == 0L)
+ val (d, ns) = toJulianDay(0)
+ assert(d === JULIAN_DAY_OF_EPOCH)
+ assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND)
+ assert(fromJulianDay(d, ns) == 0L)
val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100)
- val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t))
- val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1))
+ val (d1, ns1) = toJulianDay(fromJavaTimestamp(t))
+ val t2 = toJavaTimestamp(fromJulianDay(d1, ns1))
assert(t.equals(t2))
}
test("SPARK-6785: java date conversion before and after epoch") {
def checkFromToJavaDate(d1: Date): Unit = {
- val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1))
+ val d2 = toJavaDate(fromJavaDate(d1))
assert(d2.toString === d1.toString)
}
@@ -95,157 +96,156 @@ class DateTimeUtilsSuite extends SparkFunSuite {
}
test("string to date") {
- import DateTimeUtils.millisToDays
var c = Calendar.getInstance()
c.set(2015, 0, 28, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get ===
+ assert(stringToDate(UTF8String.fromString("2015-01-28")).get ===
millisToDays(c.getTimeInMillis))
c.set(2015, 0, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get ===
+ assert(stringToDate(UTF8String.fromString("2015")).get ===
millisToDays(c.getTimeInMillis))
c = Calendar.getInstance()
c.set(2015, 2, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get ===
+ assert(stringToDate(UTF8String.fromString("2015-03")).get ===
millisToDays(c.getTimeInMillis))
c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get ===
+ assert(stringToDate(UTF8String.fromString("2015-03-18")).get ===
millisToDays(c.getTimeInMillis))
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get ===
+ assert(stringToDate(UTF8String.fromString("2015-03-18 ")).get ===
millisToDays(c.getTimeInMillis))
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get ===
+ assert(stringToDate(UTF8String.fromString("2015-03-18 123142")).get ===
millisToDays(c.getTimeInMillis))
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get ===
+ assert(stringToDate(UTF8String.fromString("2015-03-18T123123")).get ===
millisToDays(c.getTimeInMillis))
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get ===
+ assert(stringToDate(UTF8String.fromString("2015-03-18T")).get ===
millisToDays(c.getTimeInMillis))
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015.03.18")).isEmpty)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("20150318")).isEmpty)
- assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-031-8")).isEmpty)
+ assert(stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty)
+ assert(stringToDate(UTF8String.fromString("2015/03/18")).isEmpty)
+ assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty)
+ assert(stringToDate(UTF8String.fromString("20150318")).isEmpty)
+ assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty)
}
test("string to timestamp") {
var c = Calendar.getInstance()
c.set(1969, 11, 31, 16, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get ===
c.getTimeInMillis * 1000)
c.set(2015, 0, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015")).get ===
c.getTimeInMillis * 1000)
c = Calendar.getInstance()
c.set(2015, 2, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03")).get ===
c.getTimeInMillis * 1000)
c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18")).get ===
c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get ===
c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18T")).get ===
c.getTimeInMillis * 1000)
c = Calendar.getInstance()
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get ===
c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get ===
c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get ===
c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get ===
c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get ===
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get ===
c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance()
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 456)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get ===
c.getTimeInMillis * 1000 + 121)
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get ===
c.getTimeInMillis * 1000 + 120)
@@ -254,7 +254,7 @@ class DateTimeUtilsSuite extends SparkFunSuite {
c.set(Calendar.MINUTE, 12)
c.set(Calendar.SECOND, 15)
c.set(Calendar.MILLISECOND, 0)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("18:12:15")).get ===
c.getTimeInMillis * 1000)
@@ -263,7 +263,7 @@ class DateTimeUtilsSuite extends SparkFunSuite {
c.set(Calendar.MINUTE, 12)
c.set(Calendar.SECOND, 15)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("T18:12:15.12312+7:30")).get ===
c.getTimeInMillis * 1000 + 120)
@@ -272,93 +272,130 @@ class DateTimeUtilsSuite extends SparkFunSuite {
c.set(Calendar.MINUTE, 12)
c.set(Calendar.SECOND, 15)
c.set(Calendar.MILLISECOND, 123)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("18:12:15.12312+7:30")).get ===
c.getTimeInMillis * 1000 + 120)
c = Calendar.getInstance()
c.set(2011, 4, 6, 7, 8, 9)
c.set(Calendar.MILLISECOND, 100)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty)
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty)
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty)
+ assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty)
+ assert(stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty)
+ assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty)
+ assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty)
+ assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty)
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty)
- assert(DateTimeUtils.stringToTimestamp(
+ assert(stringToTimestamp(
UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty)
}
test("hours") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 13, 2, 11)
- assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 13)
+ assert(getHours(c.getTimeInMillis * 1000) === 13)
c.set(2015, 12, 8, 2, 7, 9)
- assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 2)
+ assert(getHours(c.getTimeInMillis * 1000) === 2)
}
test("minutes") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 13, 2, 11)
- assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 2)
+ assert(getMinutes(c.getTimeInMillis * 1000) === 2)
c.set(2015, 2, 8, 2, 7, 9)
- assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 7)
+ assert(getMinutes(c.getTimeInMillis * 1000) === 7)
}
test("seconds") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 13, 2, 11)
- assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 11)
+ assert(getSeconds(c.getTimeInMillis * 1000) === 11)
c.set(2015, 2, 8, 2, 7, 9)
- assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 9)
+ assert(getSeconds(c.getTimeInMillis * 1000) === 9)
}
test("get day in year") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
- assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77)
+ assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77)
c.set(2012, 2, 18, 0, 0, 0)
- assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78)
+ assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78)
}
test("get year") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
- assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2015)
+ assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2015)
c.set(2012, 2, 18, 0, 0, 0)
- assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2012)
+ assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2012)
}
test("get quarter") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
- assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 1)
+ assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 1)
c.set(2012, 11, 18, 0, 0, 0)
- assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 4)
+ assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 4)
}
test("get month") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
- assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 3)
+ assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 3)
c.set(2012, 11, 18, 0, 0, 0)
- assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 12)
+ assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 12)
}
test("get day of month") {
val c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
- assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18)
+ assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18)
c.set(2012, 11, 24, 0, 0, 0)
- assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24)
+ assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24)
+ }
+
+ test("date add months") {
+ val c1 = Calendar.getInstance()
+ c1.set(1997, 1, 28, 10, 30, 0)
+ val days1 = millisToDays(c1.getTimeInMillis)
+ val c2 = Calendar.getInstance()
+ c2.set(2000, 1, 29)
+ assert(dateAddMonths(days1, 36) === millisToDays(c2.getTimeInMillis))
+ c2.set(1996, 0, 31)
+ assert(dateAddMonths(days1, -13) === millisToDays(c2.getTimeInMillis))
+ }
+
+ test("timestamp add months") {
+ val c1 = Calendar.getInstance()
+ c1.set(1997, 1, 28, 10, 30, 0)
+ c1.set(Calendar.MILLISECOND, 0)
+ val ts1 = c1.getTimeInMillis * 1000L
+ val c2 = Calendar.getInstance()
+ c2.set(2000, 1, 29, 10, 30, 0)
+ c2.set(Calendar.MILLISECOND, 123)
+ val ts2 = c2.getTimeInMillis * 1000L
+ assert(timestampAddInterval(ts1, 36, 123000) === ts2)
+ }
+
+ test("monthsBetween") {
+ val c1 = Calendar.getInstance()
+ c1.set(1997, 1, 28, 10, 30, 0)
+ val c2 = Calendar.getInstance()
+ c2.set(1996, 9, 30, 0, 0, 0)
+ assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677)
+ c2.set(2000, 1, 28, 0, 0, 0)
+ assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36)
+ c2.set(2000, 1, 29, 0, 0, 0)
+ assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36)
+ c2.set(1996, 2, 31, 0, 0, 0)
+ assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 114ab91d10aa0..3ea0f9ed3bddd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -40,8 +40,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
-import org.apache.spark.sql.execution.datasources.CreateTableUsingAsSelect
-import org.apache.spark.sql.json.JacksonGenerator
+import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
+import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation}
+import org.apache.spark.sql.sources.HadoopFsRelation
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -1546,6 +1547,21 @@ class DataFrame private[sql](
}
}
+ /**
+ * Returns a best-effort snapshot of the files that compose this DataFrame. This method simply
+ * asks each constituent BaseRelation for its respective files and takes the union of all results.
+ * Depending on the source relations, this may not find all input files. Duplicates are removed.
+ */
+ def inputFiles: Array[String] = {
+ val files: Seq[String] = logicalPlan.collect {
+ case LogicalRelation(fsBasedRelation: HadoopFsRelation) =>
+ fsBasedRelation.paths.toSeq
+ case LogicalRelation(jsonRelation: JSONRelation) =>
+ jsonRelation.path.toSeq
+ }.flatten
+ files.toSet.toArray
+ }
+
////////////////////////////////////////////////////////////////////////////
// for Python API
////////////////////////////////////////////////////////////////////////////
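A short usage sketch of the new method (the path is hypothetical; as the scaladoc notes, the result is best-effort and only covers file-based relations):

  // Illustrative: list the files backing a DataFrame read from a file-based source.
  val df = sqlContext.read.parquet("/path/to/events")
  df.inputFiles.sorted.foreach(println)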
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 4ec58082e7aef..2e68e358f2f1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql
+import java.{util => ju, lang => jl}
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._
@@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def freqItems(cols: Seq[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col column that defines strata
+ * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+ * its fraction as zero.
+ * @param seed random seed
+ * @tparam T stratum type
+ * @return a new [[DataFrame]] that represents the stratified sample
+ *
+ * @since 1.5.0
+ */
+ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+ require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+ s"Fractions must be in [0, 1], but got $fractions.")
+ import org.apache.spark.sql.functions.{rand, udf}
+ val c = Column(col)
+ val r = rand(seed)
+ val f = udf { (stratum: Any, x: Double) =>
+ x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
+ }
+ df.filter(f(c, r))
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col column that defines strata
+ * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+ * its fraction as zero.
+ * @param seed random seed
+ * @tparam T stratum type
+ * @return a new [[DataFrame]] that represents the stratified sample
+ *
+ * @since 1.5.0
+ */
+ def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ }
}
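A hedged usage sketch of the Scala overload (the column name and fractions are illustrative): strata not listed in the map get a fraction of 0.0 and are filtered out.

  // Illustrative: keep roughly 10% of rows whose "key" is "a" and 20% of those with "b".
  val sampled = df.stat.sampleBy("key", Map("a" -> 0.1, "b" -> 0.2), seed = 42L)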
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 40eba33f595ca..6644e85d4a037 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -229,7 +229,7 @@ private[spark] object SQLConf {
" a specific query.")
val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled",
- defaultValue = Some(false),
+ defaultValue = Some(true),
doc = "When true, use the new optimized Tungsten physical execution backend.")
val DIALECT = stringConf(
@@ -247,6 +247,13 @@ private[spark] object SQLConf {
"otherwise the schema is picked from the summary file or a random data file " +
"if no summary file is available.")
+ val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles",
+ defaultValue = Some(false),
+ doc = "When true, we make assumption that all part-files of Parquet are consistent with " +
+ "summary files and we will ignore them when merging schema. Otherwise, if this is " +
+ "false, which is the default, we will merge all part-files. This should be considered " +
+ "as expert-only option, and shouldn't be enabled before knowing what it means exactly.")
+
val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString",
defaultValue = Some(false),
doc = "Some other Parquet-producing systems, in particular Impala and older versions of " +
@@ -322,7 +329,7 @@ private[spark] object SQLConf {
" memory.")
val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin",
- defaultValue = Some(false),
+ defaultValue = Some(true),
doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.")
// This is only used for the thriftserver
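For the new spark.sql.parquet.respectSummaryFiles option above, opting in is a one-line setting; this is a hedged sketch (the SQLContext instance is assumed), and per the doc it should only be used when the summary files are known to be consistent with the part-files:

  // Expert-only: trust Parquet summary files and skip part-files when merging schemas.
  sqlContext.setConf("spark.sql.parquet.respectSummaryFiles", "true")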
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 454b7b91a63f5..1620fc401ba6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder(
precision: Int,
scale: Int)
extends NativeColumnBuilder(
- new FixedDecimalColumnStats,
+ new FixedDecimalColumnStats(precision, scale),
FIXED_DECIMAL(precision, scale))
// TODO (lian) Add support for array, struct and map
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 32a84b2676e07..af1a8ecca9b57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats {
InternalRow(null, null, nullCount, count, sizeInBytes)
}
-private[sql] class FixedDecimalColumnStats extends ColumnStats {
+private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
protected var upper: Decimal = null
protected var lower: Decimal = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getDecimal(ordinal)
+ val value = row.getDecimal(ordinal, precision, scale)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += FIXED_DECIMAL.defaultSize
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2863f6c230a9d..30f8fe320db3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -392,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
}
override def getField(row: InternalRow, ordinal: Int): Decimal = {
- row.getDecimal(ordinal)
+ row.getDecimal(ordinal, precision, scale)
}
override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
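Note: the recurring change in the columnar files above is that InternalRow.getDecimal now takes the column's precision and scale. A hedged sketch of a caller, assuming the 1.5-era catalyst row classes and a DECIMAL(10, 2) value at ordinal 0:

    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.types.Decimal

    // Build a one-column generic row holding a Decimal (illustrative values).
    val row = new GenericInternalRow(Array[Any](Decimal(BigDecimal("123.45"), 10, 2)))
    // The accessor now needs the column's precision and scale (10 and 2 here).
    val d: Decimal = row.getDecimal(0, 10, 2)
    val asJava = d.toJavaBigDecimal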
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 41a0c519ba527..6bd57f010a990 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -47,7 +47,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
override def canProcessSafeRows: Boolean = true
- override def canProcessUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = {
+ // Do not use the Unsafe path if we are using a RangePartitioning, since this may cause an
+ // interpreted RowOrdering to be applied to an UnsafeRow, which results in
+ // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed.
+ !newPartitioning.isInstanceOf[RangePartitioning]
+ }
/**
* Determines whether records must be defensively copied before being sent to the shuffle.
@@ -197,41 +202,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
- // True iff every child's outputPartitioning satisfies the corresponding
- // required data distribution.
- def meetsRequirements: Boolean =
- operator.requiredChildDistribution.zip(operator.children).forall {
- case (required, child) =>
- val valid = child.outputPartitioning.satisfies(required)
- logDebug(
- s"${if (valid) "Valid" else "Invalid"} distribution," +
- s"required: $required current: ${child.outputPartitioning}")
- valid
- }
-
- // True iff any of the children are incorrectly sorted.
- def needsAnySort: Boolean =
- operator.requiredChildOrdering.zip(operator.children).exists {
- case (required, child) => required.nonEmpty && required != child.outputOrdering
- }
-
- // True iff outputPartitionings of children are compatible with each other.
- // It is possible that every child satisfies its required data distribution
- // but two children have incompatible outputPartitionings. For example,
- // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
- // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two
- // datasets are both clustered by "a", but these two outputPartitionings are not
- // compatible.
- // TODO: ASSUMES TRANSITIVITY?
- def compatible: Boolean =
- operator.children
- .map(_.outputPartitioning)
- .sliding(2)
- .forall {
- case Seq(a) => true
- case Seq(a, b) => a.compatibleWith(b)
- }
-
// Adds Exchange or Sort operators as required
def addOperatorsIfNecessary(
partitioning: Partitioning,
@@ -264,33 +234,26 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
addSortIfNecessary(addShuffleIfNecessary(child))
}
- if (meetsRequirements && compatible && !needsAnySort) {
- operator
- } else {
- // At least one child does not satisfies its required data distribution or
- // at least one child's outputPartitioning is not compatible with another child's
- // outputPartitioning. In this case, we need to add Exchange operators.
- val requirements =
- (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
-
- val fixedChildren = requirements.zipped.map {
- case (AllTuples, rowOrdering, child) =>
- addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
- case (ClusteredDistribution(clustering), rowOrdering, child) =>
- addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
- case (OrderedDistribution(ordering), rowOrdering, child) =>
- addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+ val requirements =
+ (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
- case (UnspecifiedDistribution, Seq(), child) =>
- child
- case (UnspecifiedDistribution, rowOrdering, child) =>
- sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+ val fixedChildren = requirements.zipped.map {
+ case (AllTuples, rowOrdering, child) =>
+ addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+ case (ClusteredDistribution(clustering), rowOrdering, child) =>
+ addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+ case (OrderedDistribution(ordering), rowOrdering, child) =>
+ addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
- case (dist, ordering, _) =>
- sys.error(s"Don't know how to ensure $dist with ordering $ordering")
- }
+ case (UnspecifiedDistribution, Seq(), child) =>
+ child
+ case (UnspecifiedDistribution, rowOrdering, child) =>
+ sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
- operator.withNewChildren(fixedChildren)
+ case (dist, ordering, _) =>
+ sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}
+
+ operator.withNewChildren(fixedChildren)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 5ad4691a5ca07..d851eae3fcc71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -108,7 +108,7 @@ case class GeneratedAggregate(
Add(
Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)
- ) :: currentSum :: zero :: Nil)
+ ) :: currentSum :: Nil)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
@@ -118,45 +118,6 @@ case class GeneratedAggregate(
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
- case cs @ CombineSum(expr) =>
- val calcType =
- expr.dataType match {
- case DecimalType.Fixed(p, s) =>
- DecimalType.bounded(p + 10, s)
- case _ =>
- expr.dataType
- }
-
- val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
- val initialValue = Literal.create(null, calcType)
-
- // Coalesce avoids double calculation...
- // but really, common sub expression elimination would be better....
- val zero = Cast(Literal(0), calcType)
- // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
- // UnscaledValue will be null if and only if x is null; helps with Average on decimals
- val actualExpr = expr match {
- case UnscaledValue(e) => e
- case _ => expr
- }
- // partial sum result can be null only when no input rows present
- val updateFunction = If(
- IsNotNull(actualExpr),
- Coalesce(
- Add(
- Coalesce(currentSum :: zero :: Nil),
- Cast(expr, calcType)) :: currentSum :: zero :: Nil),
- currentSum)
-
- val result =
- expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- Cast(currentSum, cs.dataType)
- case _ => currentSum
- }
-
- AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
case m @ Max(expr) =>
val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
val initialValue = Literal.create(null, expr.dataType)
@@ -241,7 +202,7 @@ case class GeneratedAggregate(
val schemaSupportsUnsafe: Boolean = {
UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
- UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
+ UnsafeProjection.canSupport(groupKeySchema)
}
child.execute().mapPartitions { iter =>
@@ -299,12 +260,14 @@ case class GeneratedAggregate(
} else if (unsafeEnabled && schemaSupportsUnsafe) {
assert(iter.hasNext, "There should be at least one row for this path")
log.info("Using Unsafe-based aggregator")
+ val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
val aggregationMap = new UnsafeFixedWidthAggregationMap(
newAggregationBuffer(EmptyRow),
aggregationBufferSchema,
groupKeySchema,
TaskContext.get.taskMemoryManager(),
1024 * 16, // initial capacity
+ pageSizeBytes,
false // disable tracking of performance metrics
)
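Note: the page size passed to UnsafeFixedWidthAggregationMap is read from spark.buffer.pageSize and parsed with SparkConf.getSizeAsBytes, so suffixed values such as "64m" are accepted. A small sketch of that parsing, assuming the key is unset and the default applies:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // Falls back to the default string when the key is absent; "64m" parses to 67108864 bytes.
    val pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
    assert(pageSizeBytes == 64L * 1024 * 1024)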
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 2dee3542d6101..a2145b185ce90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -18,10 +18,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
@@ -37,61 +35,15 @@ object SortPrefixUtils {
def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
sortOrder.dataType match {
- case StringType => PrefixComparators.STRING
- case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL
- case FloatType => PrefixComparators.FLOAT
- case DoubleType => PrefixComparators.DOUBLE
+ case StringType if sortOrder.isAscending => PrefixComparators.STRING
+ case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC
+ case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending =>
+ PrefixComparators.LONG
+ case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending =>
+ PrefixComparators.LONG_DESC
+ case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE
+ case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC
case _ => NoOpPrefixComparator
}
}
-
- def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = {
- sortOrder.dataType match {
- case StringType => (row: InternalRow) => {
- PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String])
- }
- case BooleanType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1
- else 0
- }
- case ByteType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Byte]
- }
- case ShortType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Short]
- }
- case IntegerType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Int]
- }
- case LongType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Long]
- }
- case FloatType => (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX
- else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
- }
- case DoubleType => (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX
- else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
- }
- case _ => (row: InternalRow) => 0L
- }
- }
}
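Note: with the descending-aware comparators above, the choice now depends on both the data type and the sort direction. A hedged usage sketch, assuming a bound StringType sort column at ordinal 0 and that SortPrefixUtils is on the classpath:

    import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, SortOrder}
    import org.apache.spark.sql.types.StringType

    val col = BoundReference(0, StringType, nullable = true)
    // Ascending strings map to PrefixComparators.STRING, descending strings to STRING_DESC.
    val asc = SortPrefixUtils.getPrefixComparator(SortOrder(col, Ascending))
    val desc = SortPrefixUtils.getPrefixComparator(SortOrder(col, Descending))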
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index c808442a4849b..e5bbd0aaed0a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
- val value = row.getDecimal(i)
+ val value = row.getDecimal(i, decimal.precision, decimal.scale)
val javaBigDecimal = value.toJavaBigDecimal
// First, write out the unscaled value.
val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 306bbfec624c0..03d24a88d4ecd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -193,15 +193,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
- def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
- aggregate.Utils.tryConvert(
- plan,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled).isDefined
+ def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match {
+ case a: logical.Aggregate =>
+ if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) {
+ a.newAggregation.isDefined
+ } else {
+ Utils.checkInvalidAggregateFunction2(a)
+ false
+ }
+ case _ => false
}
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
- case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
+ case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
@@ -217,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case p: logical.Aggregate =>
- val converted =
- aggregate.Utils.tryConvert(
- p,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled)
+ case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 &&
+ sqlContext.conf.codegenEnabled =>
+ val converted = p.newAggregation
converted match {
case None => Nil // Cannot convert to new aggregation code path.
case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
@@ -339,8 +340,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* if necessary.
*/
def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = {
- if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) {
- execution.UnsafeExternalSort(sortExprs, global, child)
+ if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled &&
+ TungstenSort.supportsSchema(child.schema)) {
+ execution.TungstenSort(sortExprs, global, child)
} else if (sqlContext.conf.externalSortEnabled) {
execution.ExternalSort(sortExprs, global, child)
} else {
@@ -363,23 +365,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Sort(sortExprs, global, child) =>
getSortOperator(sortExprs, global, planLater(child)):: Nil
case logical.Project(projectList, child) =>
- execution.Project(projectList, planLater(child)) :: Nil
+ // If unsafe mode is enabled and we support these data types in Unsafe, use the
+ // Tungsten project. Otherwise, use the normal project.
+ if (sqlContext.conf.unsafeEnabled &&
+ UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) {
+ execution.TungstenProject(projectList, planLater(child)) :: Nil
+ } else {
+ execution.Project(projectList, planLater(child)) :: Nil
+ }
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case a @ logical.Aggregate(group, agg, child) => {
- val useNewAggregation =
- aggregate.Utils.tryConvert(
- a,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled).isDefined
- if (useNewAggregation) {
+ val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
+ if (useNewAggregation && a.newAggregation.isDefined) {
// If this logical.Aggregate can be planned to use new aggregation code path
// (i.e. it can be planned by the Strategy Aggregation), we will not use the old
// aggregation code path.
Nil
} else {
+ Utils.checkInvalidAggregateFunction2(a)
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
index 0c9082897f390..98538c462bc89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -72,8 +72,10 @@ case class Aggregate2Sort(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
if (aggregateExpressions.length == 0) {
- new GroupingIterator(
+ new FinalSortAggregationIterator(
groupingExpressions,
+ Nil,
+ Nil,
resultExpressions,
newMutableProjection,
child.output,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
index 1b89edafa8dad..2ca0cb82c1aab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -41,7 +41,8 @@ private[sql] abstract class SortAggregationIterator(
///////////////////////////////////////////////////////////////////////////
protected val aggregateFunctions: Array[AggregateFunction2] = {
- var bufferOffset = initialBufferOffset
+ var mutableBufferOffset = 0
+ var inputBufferOffset: Int = initialInputBufferOffset
val functions = new Array[AggregateFunction2](aggregateExpressions.length)
var i = 0
while (i < aggregateExpressions.length) {
@@ -54,13 +55,18 @@ private[sql] abstract class SortAggregationIterator(
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
BindReferences.bindReference(func, inputAttributes)
- case _ => func
+ case _ =>
+ // We only need to set inputBufferOffset for aggregate functions with mode
+ // PartialMerge and Final.
+ func.inputBufferOffset = inputBufferOffset
+ inputBufferOffset += func.bufferSchema.length
+ func
}
- // Set bufferOffset for this function. It is important that setting bufferOffset
- // happens after all potential bindReference operations because bindReference
- // will create a new instance of the function.
- funcWithBoundReferences.bufferOffset = bufferOffset
- bufferOffset += funcWithBoundReferences.bufferSchema.length
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
+ funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset
+ mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
functions(i) = funcWithBoundReferences
i += 1
}
@@ -97,25 +103,24 @@ private[sql] abstract class SortAggregationIterator(
// The number of elements of the underlying buffer of this operator.
// All aggregate functions are sharing this underlying buffer and they find their
// buffer values through bufferOffset.
- var size = initialBufferOffset
- var i = 0
- while (i < aggregateFunctions.length) {
- size += aggregateFunctions(i).bufferSchema.length
- i += 1
- }
- new GenericMutableRow(size)
+ new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum)
}
protected val joinedRow = new JoinedRow
- protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
-
// This projection is used to initialize buffer values for all AlgebraicAggregates.
protected val algebraicInitialProjection = {
- val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ val initExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.initialValues
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
+
newMutableProjection(initExpressions, Nil)().target(buffer)
}
@@ -132,10 +137,6 @@ private[sql] abstract class SortAggregationIterator(
// Indicates if we has new group of rows to process.
protected var hasNewGroup: Boolean = true
- ///////////////////////////////////////////////////////////////////////////
- // Private methods
- ///////////////////////////////////////////////////////////////////////////
-
/** Initializes buffer values for all aggregate functions. */
protected def initializeBuffer(): Unit = {
algebraicInitialProjection(EmptyRow)
@@ -160,6 +161,10 @@ private[sql] abstract class SortAggregationIterator(
}
}
+ ///////////////////////////////////////////////////////////////////////////
+ // Private methods
+ ///////////////////////////////////////////////////////////////////////////
+
/** Processes rows in the current group. It will stop when it finds a new group. */
private def processCurrentGroup(): Unit = {
currentGroupingKey = nextGroupingKey
@@ -218,10 +223,13 @@ private[sql] abstract class SortAggregationIterator(
// Methods that need to be implemented
///////////////////////////////////////////////////////////////////////////
- protected def initialBufferOffset: Int
+ /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */
+ protected def initialInputBufferOffset: Int
+ /** The function used to process an input row. */
protected def processRow(row: InternalRow): Unit
+ /** The function used to generate the result row. */
protected def generateOutput(): InternalRow
///////////////////////////////////////////////////////////////////////////
@@ -231,37 +239,6 @@ private[sql] abstract class SortAggregationIterator(
initialize()
}
-/**
- * An iterator only used to group input rows according to values of `groupingExpressions`.
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- */
-class GroupingIterator(
- groupingExpressions: Seq[NamedExpression],
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- Nil,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- private val resultProjection =
- newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
-
- override protected def initialBufferOffset: Int = 0
-
- override protected def processRow(row: InternalRow): Unit = {
- // Since we only do grouping, there is nothing to do at here.
- }
-
- override protected def generateOutput(): InternalRow = {
- resultProjection(currentGroupingKey)
- }
-}
-
/**
* An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
* It assumes that input rows are already grouped by values of `groupingExpressions`.
@@ -291,7 +268,7 @@ class PartialSortAggregationIterator(
newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
}
- override protected def initialBufferOffset: Int = 0
+ override protected def initialInputBufferOffset: Int = 0
override protected def processRow(row: InternalRow): Unit = {
// Process all algebraic aggregate functions.
@@ -318,11 +295,7 @@ class PartialSortAggregationIterator(
* |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
*
* The format of its internal buffer is:
- * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
- * Every placeholder is for a grouping expression.
- * The actual buffers are stored after placeholderN.
- * The reason that we have placeholders at here is to make our underlying buffer have the same
- * length with a input row.
+ * |aggregationBuffer1|...|aggregationBufferN|
*
* The format of its output rows is:
* |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
@@ -340,33 +313,21 @@ class PartialMergeSortAggregationIterator(
inputAttributes,
inputIter) {
- private val placeholderAttributes =
- Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
-
// This projection is used to merge buffer values for all AlgebraicAggregates.
private val algebraicMergeProjection = {
- val bufferSchemata =
- placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ val mergeInputSchema =
+ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ groupingExpressions.map(_.toAttribute) ++
+ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
- newMutableProjection(mergeExpressions, bufferSchemata)()
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
}
- // This projection is used to extract aggregation buffers from the underlying buffer.
- // We need it because the underlying buffer has placeholders at its beginning.
- private val extractsBufferValues = {
- val expressions = aggregateFunctions.flatMap {
- case agg => agg.bufferAttributes
- }
-
- newMutableProjection(expressions, inputAttributes)()
- }
-
- override protected def initialBufferOffset: Int = groupingExpressions.length
+ override protected def initialInputBufferOffset: Int = groupingExpressions.length
override protected def processRow(row: InternalRow): Unit = {
// Process all algebraic aggregate functions.
@@ -381,7 +342,7 @@ class PartialMergeSortAggregationIterator(
override protected def generateOutput(): InternalRow = {
// We output grouping expressions and aggregation buffers.
- joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+ joinedRow(currentGroupingKey, buffer).copy()
}
}
@@ -393,11 +354,7 @@ class PartialMergeSortAggregationIterator(
* |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
*
* The format of its internal buffer is:
- * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN|
- * Every placeholder is for a grouping expression.
- * The actual buffers are stored after placeholderN.
- * The reason that we have placeholders at here is to make our underlying buffer have the same
- * length with a input row.
+ * |aggregationBuffer1|...|aggregationBufferN|
*
* The format of its output rows is represented by the schema of `resultExpressions`.
*/
@@ -425,27 +382,23 @@ class FinalSortAggregationIterator(
newMutableProjection(
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
- private val offsetAttributes =
- Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
-
// This projection is used to merge buffer values for all AlgebraicAggregates.
private val algebraicMergeProjection = {
- val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ val mergeInputSchema =
+ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ groupingExpressions.map(_.toAttribute) ++
+ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
- newMutableProjection(mergeExpressions, bufferSchemata)()
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
}
// This projection is used to evaluate all AlgebraicAggregates.
private val algebraicEvalProjection = {
- val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
@@ -454,7 +407,7 @@ class FinalSortAggregationIterator(
newMutableProjection(evalExpressions, bufferSchemata)()
}
- override protected def initialBufferOffset: Int = groupingExpressions.length
+ override protected def initialInputBufferOffset: Int = groupingExpressions.length
override def initialize(): Unit = {
if (inputIter.hasNext) {
@@ -471,7 +424,10 @@ class FinalSortAggregationIterator(
// Right now, the buffer only contains initial buffer values. Because
// merging two buffers with initial values will generate a row that
// still store initial values. We set the currentRow as the copy of the current buffer.
- val currentRow = buffer.copy()
+ // Because the input aggregation buffer has initialInputBufferOffset extra values at the
+ // beginning, we create a dummy row for this part.
+ val currentRow =
+ joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
nextGroupingKey = groupGenerator(currentRow).copy()
firstRowInNextGroup = currentRow
} else {
@@ -518,18 +474,15 @@ class FinalSortAggregationIterator(
* Final mode.
*
* The format of its internal buffer is:
- * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
- * The first N placeholders represent slots of grouping expressions.
- * Then, next M placeholders represent slots of col1 to colM.
+ * |aggregationBuffer1|...|aggregationBuffer(N+M)|
* For aggregation buffers, first N aggregation buffers are used by N aggregate functions with
* mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode
- * Complete. The reason that we have placeholders at here is to make our underlying buffer
- * have the same length with a input row.
+ * Complete.
*
* The format of its output rows is represented by the schema of `resultExpressions`.
*/
class FinalAndCompleteSortAggregationIterator(
- override protected val initialBufferOffset: Int,
+ override protected val initialInputBufferOffset: Int,
groupingExpressions: Seq[NamedExpression],
finalAggregateExpressions: Seq[AggregateExpression2],
finalAggregateAttributes: Seq[Attribute],
@@ -561,9 +514,6 @@ class FinalAndCompleteSortAggregationIterator(
newMutableProjection(resultExpressions, inputSchema)()
}
- private val offsetAttributes =
- Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
-
// All aggregate functions with mode Final.
private val finalAggregateFunctions: Array[AggregateFunction2] = {
val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
@@ -601,38 +551,38 @@ class FinalAndCompleteSortAggregationIterator(
// This projection is used to merge buffer values for all AlgebraicAggregates with mode
// Final.
private val finalAlgebraicMergeProjection = {
- val numCompleteOffsetAttributes =
- completeAggregateFunctions.map(_.bufferAttributes.length).sum
- val completeOffsetAttributes =
- Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)())
- val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
-
- val bufferSchemata =
- offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
- completeOffsetAttributes ++ offsetAttributes ++
- finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes
+ // The first initialInputBufferOffset values of the input aggregation buffer are
+ // for grouping expressions and distinct columns.
+ val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset)
+
+ val completeOffsetExpressions =
+ Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+
+ val mergeInputSchema =
+ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
+ completeAggregateFunctions.flatMap(_.bufferAttributes) ++
+ groupingAttributesAndDistinctColumns ++
+ finalAggregateFunctions.flatMap(_.cloneBufferAttributes)
val mergeExpressions =
- placeholderExpressions ++ finalAggregateFunctions.flatMap {
+ finalAggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
} ++ completeOffsetExpressions
-
- newMutableProjection(mergeExpressions, bufferSchemata)()
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
}
// This projection is used to update buffer values for all AlgebraicAggregates with mode
// Complete.
private val completeAlgebraicUpdateProjection = {
- val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum
- val finalOffsetAttributes =
- Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)())
- val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
+ // We do not touch buffer values of aggregate functions with the Final mode.
+ val finalOffsetExpressions =
+ Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
val bufferSchema =
- offsetAttributes ++ finalOffsetAttributes ++
+ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
completeAggregateFunctions.flatMap(_.bufferAttributes)
val updateExpressions =
- placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.updateExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
@@ -641,9 +591,7 @@ class FinalAndCompleteSortAggregationIterator(
// This projection is used to evaluate all AlgebraicAggregates.
private val algebraicEvalProjection = {
- val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
@@ -667,7 +615,10 @@ class FinalAndCompleteSortAggregationIterator(
// Right now, the buffer only contains initial buffer values. Because
// merging two buffers with initial values will generate a row that
// still store initial values. We set the currentRow as the copy of the current buffer.
- val currentRow = buffer.copy()
+ // Because the input aggregation buffer has initialInputBufferOffset extra values at the
+ // beginning, we create a dummy row for this part.
+ val currentRow =
+ joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
nextGroupingKey = groupGenerator(currentRow).copy()
firstRowInNextGroup = currentRow
} else {
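Note: the split of the old bufferOffset into mutableBufferOffset and inputBufferOffset is easiest to see on concrete numbers: the operator's mutable buffer holds only aggregation buffers, while input rows in PartialMerge/Final mode still carry the grouping columns first. A standalone sketch of the two offset assignments, assuming three functions with buffer lengths 2, 1 and 3 and two grouping columns:

    val bufferLengths = Seq(2, 1, 3)  // bufferSchema.length of each aggregate function (assumed)
    val initialInputBufferOffset = 2  // number of grouping columns in the input rows (assumed)

    // Offsets into the operator's own mutable buffer start at 0.
    val mutableOffsets = bufferLengths.scanLeft(0)(_ + _).init                        // Seq(0, 2, 3)
    // Offsets into incoming rows start after the grouping columns.
    val inputOffsets = bufferLengths.scanLeft(initialInputBufferOffset)(_ + _).init  // Seq(2, 4, 5)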
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 073c45ae2f9f2..cc54319171bdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -184,7 +184,7 @@ private[sql] case class ScalaUDAF(
bufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
- bufferOffset,
+ inputBufferOffset,
null)
lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
@@ -192,9 +192,16 @@ private[sql] case class ScalaUDAF(
bufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
- bufferOffset,
+ mutableBufferOffset,
null)
+ lazy val evalAggregateBuffer: InputAggregationBuffer =
+ new InputAggregationBuffer(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableBufferOffset,
+ null)
override def initialize(buffer: MutableRow): Unit = {
mutableAggregateBuffer.underlyingBuffer = buffer
@@ -217,10 +224,10 @@ private[sql] case class ScalaUDAF(
udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
}
- override def eval(buffer: InternalRow = null): Any = {
- inputAggregateBuffer.underlyingInputBuffer = buffer
+ override def eval(buffer: InternalRow): Any = {
+ evalAggregateBuffer.underlyingInputBuffer = buffer
- udaf.evaluate(inputAggregateBuffer)
+ udaf.evaluate(evalAggregateBuffer)
}
override def toString: String = {
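Note: ScalaUDAF is the internal wrapper around user-defined aggregate functions, and the mutable/input/eval buffer distinction above only matters inside that wrapper. For orientation, here is a minimal user-facing UDAF of the kind that ends up wrapped here; the class name and semantics are illustrative, and the method set follows the public UserDefinedAggregateFunction API as I understand it in the 1.5 line:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    // A plain long sum, ignoring nulls.
    class LongSum extends UserDefinedAggregateFunction {
      override def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
      override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: Nil)
      override def dataType: DataType = LongType
      override def deterministic: Boolean = true

      override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0L

      override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
        if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getLong(0)

      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)

      override def evaluate(buffer: Row): Any = buffer.getLong(0)
    }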
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 5bbe6c162ff4b..03635baae4a5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
- // Right now, we do not support complex types in the grouping key schema.
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
- val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
- case array: ArrayType => true
- case map: MapType => true
- case struct: StructType => true
- case _ => false
- }
-
- !hasComplexTypes
- }
-
- private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
- case p: Aggregate if supportsGroupingKeySchema(p) =>
- val converted = p.transformExpressionsDown {
- case expressions.Average(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Average(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Count(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- // We do not support multiple COUNT DISTINCT columns for now.
- case expressions.CountDistinct(children) if children.length == 1 =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(children.head),
- mode = aggregate.Complete,
- isDistinct = true)
-
- case expressions.First(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.First(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Last(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Last(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Max(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Max(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Min(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Min(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Sum(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.SumDistinct(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = true)
- }
- // Check if there is any expressions.AggregateExpression1 left.
- // If so, we cannot convert this plan.
- val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
- // For every expressions, check if it contains AggregateExpression1.
- expr.find {
- case agg: expressions.AggregateExpression1 => true
- case other => false
- }.isDefined
- }
-
- // Check if there are multiple distinct columns.
- val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg
- }
- }.toSet.toSeq
- val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
- val hasMultipleDistinctColumnSets =
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
- true
- } else {
- false
- }
-
- if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
-
- case other => None
- }
-
- private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
- // If the plan cannot be converted, we will do a final round check to if the original
- // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
- // we need to throw an exception.
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg.aggregateFunction
- }
- }.distinct
- if (aggregateFunction2s.nonEmpty) {
- // For functions implemented based on the new interface, prepare a list of function names.
- val invalidFunctions = {
- if (aggregateFunction2s.length > 1) {
- s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
- s"and ${aggregateFunction2s.head.nodeName} are"
- } else {
- s"${aggregateFunction2s.head.nodeName} is"
- }
- }
- val errorMessage =
- s"${invalidFunctions} implemented based on the new Aggregate Function " +
- s"interface and it cannot be used with functions implemented based on " +
- s"the old Aggregate Function interface."
- throw new AnalysisException(errorMessage)
- }
- }
-
- def tryConvert(
- plan: LogicalPlan,
- useNewAggregation: Boolean,
- codeGenEnabled: Boolean): Option[Aggregate] = plan match {
- case p: Aggregate if useNewAggregation && codeGenEnabled =>
- val converted = tryConvert(p)
- if (converted.isDefined) {
- converted
- } else {
- checkInvalidAggregateFunction2(p)
- None
- }
- case p: Aggregate =>
- checkInvalidAggregateFunction2(p)
- None
- case other => None
- }
-
def planAggregateWithoutDistinct(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[AggregateExpression2],
@@ -292,8 +148,8 @@ object Utils {
AggregateExpression2(aggregateFunction, PartialMerge, false)
}
val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ partialMergeAggregateExpressions.flatMap { agg =>
+ agg.aggregateFunction.bufferAttributes
}
val partialMergeAggregate =
Aggregate2Sort(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fe429d862a0a3..2294a670c735f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -49,6 +49,31 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
+
+/**
+ * A variant of [[Project]] that returns [[UnsafeRow]]s.
+ */
+case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
+
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
+ override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+
+ protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+ this.transformAllExpressions {
+ case CreateStruct(children) => CreateStructUnsafe(children)
+ case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+ }
+ val project = UnsafeProjection.create(projectList, child.output)
+ iter.map(project)
+ }
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+
/**
* :: DeveloperApi ::
*/
@@ -195,137 +220,6 @@ case class TakeOrderedAndProject(
override def outputOrdering: Seq[SortOrder] = sortOrder
}
-/**
- * :: DeveloperApi ::
- * Performs a sort on-heap.
- * @param global when true performs a global sort of all partitions by shuffling the data first
- * if necessary.
- */
-@DeveloperApi
-case class Sort(
- sortOrder: Seq[SortOrder],
- global: Boolean,
- child: SparkPlan)
- extends UnaryNode {
- override def requiredChildDistribution: Seq[Distribution] =
- if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
- child.execute().mapPartitions( { iterator =>
- val ordering = newOrdering(sortOrder, child.output)
- iterator.map(_.copy()).toArray.sorted(ordering).iterator
- }, preservesPartitioning = true)
- }
-
- override def output: Seq[Attribute] = child.output
-
- override def outputOrdering: Seq[SortOrder] = sortOrder
-}
-
-/**
- * :: DeveloperApi ::
- * Performs a sort, spilling to disk as needed.
- * @param global when true performs a global sort of all partitions by shuffling the data first
- * if necessary.
- */
-@DeveloperApi
-case class ExternalSort(
- sortOrder: Seq[SortOrder],
- global: Boolean,
- child: SparkPlan)
- extends UnaryNode {
-
- override def requiredChildDistribution: Seq[Distribution] =
- if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
- child.execute().mapPartitions( { iterator =>
- val ordering = newOrdering(sortOrder, child.output)
- val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
- sorter.insertAll(iterator.map(r => (r.copy, null)))
- val baseIterator = sorter.iterator.map(_._1)
- // TODO(marmbrus): The complex type signature below thwarts inference for no reason.
- CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
- }, preservesPartitioning = true)
- }
-
- override def output: Seq[Attribute] = child.output
-
- override def outputOrdering: Seq[SortOrder] = sortOrder
-}
-
-/**
- * :: DeveloperApi ::
- * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of
- * Project Tungsten).
- *
- * @param global when true performs a global sort of all partitions by shuffling the data first
- * if necessary.
- * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
- * spill every `frequency` records.
- */
-@DeveloperApi
-case class UnsafeExternalSort(
- sortOrder: Seq[SortOrder],
- global: Boolean,
- child: SparkPlan,
- testSpillFrequency: Int = 0)
- extends UnaryNode {
-
- private[this] val schema: StructType = child.schema
-
- override def requiredChildDistribution: Seq[Distribution] =
- if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
- assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
- def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
- val ordering = newOrdering(sortOrder, child.output)
- val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output)
- // Hack until we generate separate comparator implementations for ascending vs. descending
- // (or choose to codegen them):
- val prefixComparator = {
- val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression)
- if (sortOrder.head.direction == Descending) {
- new PrefixComparator {
- override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2)
- }
- } else {
- comp
- }
- }
- val prefixComputer = {
- val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression)
- new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = prefixComputer(row)
- }
- }
- val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
- if (testSpillFrequency > 0) {
- sorter.setTestSpillFrequency(testSpillFrequency)
- }
- sorter.sort(iterator)
- }
- child.execute().mapPartitions(doSort, preservesPartitioning = true)
- }
-
- override def output: Seq[Attribute] = child.output
-
- override def outputOrdering: Seq[SortOrder] = sortOrder
-
- override def outputsUnsafeRows: Boolean = true
-}
-
-@DeveloperApi
-object UnsafeExternalSort {
- /**
- * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise.
- */
- def supportsSchema(schema: StructType): Boolean = {
- UnsafeExternalRowSorter.supportsSchema(schema)
- }
-}
-
/**
* :: DeveloperApi ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index e73b3704d4dfe..0cdb407ad57b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -308,7 +308,7 @@ private[sql] object ResolvedDataSource {
mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
- if (data.schema.map(_.dataType).exists(_.isInstanceOf[IntervalType])) {
+ if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
throw new AnalysisException("Cannot save interval data type into external storage.")
}
val clazz: Class[_] = lookupDataSource(provider)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index aeeb0e45270dd..f26f41fb75d57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -158,8 +158,8 @@ package object debug {
case (row: InternalRow, StructType(fields)) =>
row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
- case (s: Seq[_], ArrayType(elemType, _)) =>
- s.foreach(typeCheck(_, elemType))
+ case (a: ArrayData, ArrayType(elemType, _)) =>
+ a.toArray().foreach(typeCheck(_, elemType))
case (m: Map[_, _], MapType(keyType, valueType, _)) =>
m.keys.foreach(typeCheck(_, keyType))
m.values.foreach(typeCheck(_, valueType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala
deleted file mode 100644
index 568b7ac2c5987..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-/**
- * Package containing expressions that are specific to Spark runtime.
- */
-package object expressions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index abaa4a6ce86a2..624efc1b1d734 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- val hashed = buildHashRelation(input.iterator)
+ val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index c9d1a880f4ef4..77e7fe71009b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- val hashed = buildHashRelation(input.iterator)
+ val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index f71c0ce352904..a60593911f94f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash(
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
protected override def doExecute(): RDD[InternalRow] = {
- val buildIter = right.execute().map(_.copy()).collect().toIterator
+ val input = right.execute().map(_.copy()).collect()
if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(buildIter)
+ val hashSet = buildKeyHashSet(input.toIterator)
val broadcastedRelation = sparkContext.broadcast(hashSet)
left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
- val hashRelation = buildHashRelation(buildIter)
+ val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size)
val broadcastedRelation = sparkContext.broadcast(hashRelation)
left.execute().mapPartitions { streamIter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 700636966f8be..83b726a8e2897 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin(
override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
override def canProcessUnsafeRows: Boolean = true
- @transient private[this] lazy val resultProjection: Projection = {
+ @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
if (outputsUnsafeRows) {
UnsafeProjection.create(schema)
} else {
- new Projection {
- override def apply(r: InternalRow): InternalRow = r
- }
+ identity[InternalRow]
}
}
@@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin(
var streamRowMatched = false
while (i < broadcastedRelation.value.size) {
- // TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
@@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin(
val buf: CompactBuffer[InternalRow] = new CompactBuffer()
var i = 0
val rel = broadcastedRelation.value
- while (i < rel.length) {
- if (!allIncludedBroadcastTuples.contains(i)) {
- (joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) =>
- buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
- case (LeftOuter | FullOuter, BuildLeft) =>
- buf += resultProjection(new JoinedRow(rel(i), rightNulls))
- case _ =>
+ (joinType, buildSide) match {
+ case (RightOuter | FullOuter, BuildRight) =>
+ val joinedRow = new JoinedRow
+ joinedRow.withLeft(leftNulls)
+ while (i < rel.length) {
+ if (!allIncludedBroadcastTuples.contains(i)) {
+ buf += resultProjection(joinedRow.withRight(rel(i))).copy()
+ }
+ i += 1
}
- }
- i += 1
+ case (LeftOuter | FullOuter, BuildLeft) =>
+ val joinedRow = new JoinedRow
+ joinedRow.withRight(rightNulls)
+ while (i < rel.length) {
+ if (!allIncludedBroadcastTuples.contains(i)) {
+ buf += resultProjection(joinedRow.withLeft(rel(i))).copy()
+ }
+ i += 1
+ }
+ case _ =>
}
buf.toSeq
}
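The rewritten loop above hoists the (joinType, buildSide) dispatch out of the per-row loop and reuses a single JoinedRow, copying each projected result because the underlying row object is mutated on every iteration. A minimal sketch of that pattern, using plain Scala stand-ins rather than Spark's row classes:

// Minimal sketch (not part of the patch): fix one side of a reusable, mutable pair
// outside the loop, mutate the other side per element, and copy before buffering.
final class ReusablePair(var left: String = null, var right: String = null) {
  def withLeft(l: String): this.type = { left = l; this }
  def withRight(r: String): this.type = { right = r; this }
  def copyOut: (String, String) = (left, right) // copy, since this object is reused
}

def padUnmatched(rows: Array[String], matched: Int => Boolean): Seq[(String, String)] = {
  val buf = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
  val pair = new ReusablePair().withLeft("NULL") // fixed side set once, outside the loop
  var i = 0
  while (i < rows.length) {
    if (!matched(i)) {
      buf += pair.withRight(rows(i)).copyOut
    }
    i += 1
  }
  buf
}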
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 46ab5b0d1cc6d..6b3d1652923fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.util.collection.CompactBuffer
trait HashJoin {
@@ -44,16 +43,24 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
- protected[this] def supportUnsafe: Boolean = {
+ protected[this] def isUnsafeMode: Boolean = {
(self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}
- override def outputsUnsafeRows: Boolean = supportUnsafe
- override def canProcessUnsafeRows: Boolean = supportUnsafe
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
+
+ @transient protected lazy val buildSideKeyGenerator: Projection =
+ if (isUnsafeMode) {
+ UnsafeProjection.create(buildKeys, buildPlan.output)
+ } else {
+ newMutableProjection(buildKeys, buildPlan.output)()
+ }
@transient protected lazy val streamSideKeyGenerator: Projection =
- if (supportUnsafe) {
+ if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newMutableProjection(streamedKeys, streamedPlan.output)()
@@ -65,18 +72,16 @@ trait HashJoin {
{
new Iterator[InternalRow] {
private[this] var currentStreamedRow: InternalRow = _
- private[this] var currentHashMatches: CompactBuffer[InternalRow] = _
+ private[this] var currentHashMatches: Seq[InternalRow] = _
private[this] var currentMatchPosition: Int = -1
// Mutable per row objects.
private[this] val joinRow = new JoinedRow
- private[this] val resultProjection: Projection = {
- if (supportUnsafe) {
+ private[this] val resultProjection: (InternalRow) => InternalRow = {
+ if (isUnsafeMode) {
UnsafeProjection.create(self.schema)
} else {
- new Projection {
- override def apply(r: InternalRow): InternalRow = r
- }
+ identity[InternalRow]
}
}
@@ -122,12 +127,4 @@ trait HashJoin {
}
}
}
-
- protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
- if (supportUnsafe) {
- UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
- } else {
- HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
- }
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6bf2f82954046..7e671e7914f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -75,30 +75,36 @@ trait HashOuterJoin {
s"HashOuterJoin should not take $x as the JoinType")
}
- protected[this] def supportUnsafe: Boolean = {
+ protected[this] def isUnsafeMode: Boolean = {
(self.codegenEnabled && joinType != FullOuter
&& UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}
- override def outputsUnsafeRows: Boolean = supportUnsafe
- override def canProcessUnsafeRows: Boolean = supportUnsafe
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
- protected[this] def streamedKeyGenerator(): Projection = {
- if (supportUnsafe) {
+ @transient protected lazy val buildKeyGenerator: Projection =
+ if (isUnsafeMode) {
+ UnsafeProjection.create(buildKeys, buildPlan.output)
+ } else {
+ newMutableProjection(buildKeys, buildPlan.output)()
+ }
+
+ @transient protected[this] lazy val streamedKeyGenerator: Projection = {
+ if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newProjection(streamedKeys, streamedPlan.output)
}
}
- @transient private[this] lazy val resultProjection: Projection = {
- if (supportUnsafe) {
+ @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
+ if (isUnsafeMode) {
UnsafeProjection.create(self.schema)
} else {
- new Projection {
- override def apply(r: InternalRow): InternalRow = r
- }
+ identity[InternalRow]
}
}
@@ -230,12 +236,4 @@ trait HashOuterJoin {
hashTable
}
-
- protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
- if (supportUnsafe) {
- UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
- } else {
- HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
- }
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 7f49264d40354..97fde8f975bfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -35,11 +35,13 @@ trait HashSemiJoin {
protected[this] def supportUnsafe: Boolean = {
(self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
&& UnsafeProjection.canSupport(rightKeys)
- && UnsafeProjection.canSupport(left.schema))
+ && UnsafeProjection.canSupport(left.schema)
+ && UnsafeProjection.canSupport(right.schema))
}
- override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
+ override def outputsUnsafeRows: Boolean = supportUnsafe
override def canProcessUnsafeRows: Boolean = supportUnsafe
+ override def canProcessSafeRows: Boolean = !supportUnsafe
@transient protected lazy val leftKeyGenerator: Projection =
if (supportUnsafe) {
@@ -87,14 +89,6 @@ trait HashSemiJoin {
})
}
- protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
- if (supportUnsafe) {
- UnsafeHashedRelation(buildIter, rightKeys, right)
- } else {
- HashedRelation(buildIter, newProjection(rightKeys, right.output))
- }
- }
-
protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8d5731afd59b8..f88a45f48aee9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,12 +18,16 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
+import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.util.collection.CompactBuffer
@@ -32,7 +36,7 @@ import org.apache.spark.util.collection.CompactBuffer
* object.
*/
private[joins] sealed trait HashedRelation {
- def get(key: InternalRow): CompactBuffer[InternalRow]
+ def get(key: InternalRow): Seq[InternalRow]
// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
@@ -59,9 +63,9 @@ private[joins] final class GeneralHashedRelation(
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {
- def this() = this(null) // Needed for serialization
+ private def this() = this(null) // Needed for serialization
- override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key)
+ override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)
override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -81,9 +85,9 @@ private[joins]
final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
extends HashedRelation with Externalizable {
- def this() = this(null) // Needed for serialization
+ private def this() = this(null) // Needed for serialization
- override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+ override def get(key: InternalRow): Seq[InternalRow] = {
val v = hashTable.get(key)
if (v eq null) null else CompactBuffer(v)
}
@@ -109,6 +113,10 @@ private[joins] object HashedRelation {
keyGenerator: Projection,
sizeEstimate: Int = 64): HashedRelation = {
+ if (keyGenerator.isInstanceOf[UnsafeProjection]) {
+ return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+ }
+
// TODO: Use Spark's HashMap implementation.
val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
var currentRow: InternalRow = null
@@ -149,31 +157,140 @@ private[joins] object HashedRelation {
}
}
-
/**
- * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
- * sequence of values.
+ * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
+ * into a sequence of values.
+ *
+ * When it is created, it uses a JavaHashMap. After it is serialized and deserialized, it
+ * switches to a BytesToBytesMap for better memory performance (multiple values for the same
+ * key are stored as a contiguous byte array).
*
- * TODO(davies): use BytesToBytesMap
+ * It's serialized in the following format:
+ * [number of keys]
+ * [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
+ * ...
+ *
+ * All the values are serialized as following:
+ * [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
+ * ...
*/
private[joins] final class UnsafeHashedRelation(
private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
extends HashedRelation with Externalizable {
- def this() = this(null) // Needed for serialization
+ private[joins] def this() = this(null) // Needed for serialization
+
+ // Use a BytesToBytesMap on executors for better performance (it is created during deserialization)
+ @transient private[this] var binaryMap: BytesToBytesMap = _
- override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+ override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
- // Thanks to type eraser
- hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+
+ if (binaryMap != null) {
+ // Used in Broadcast join
+ val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes)
+ if (loc.isDefined) {
+ val buffer = CompactBuffer[UnsafeRow]()
+
+ val base = loc.getValueAddress.getBaseObject
+ var offset = loc.getValueAddress.getBaseOffset
+ val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+ while (offset < last) {
+ val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
+ val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
+ offset += 8
+
+ val row = new UnsafeRow
+ row.pointTo(base, offset, numFields, sizeInBytes)
+ buffer += row
+ offset += sizeInBytes
+ }
+ buffer
+ } else {
+ null
+ }
+
+ } else {
+ // Use the JavaHashMap in Local mode or ShuffleHashJoin
+ hashTable.get(unsafeKey)
+ }
}
override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ out.writeInt(hashTable.size())
+
+ val iter = hashTable.entrySet().iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val key = entry.getKey
+ val values = entry.getValue
+
+ // write all the values as a single byte array
+ var totalSize = 0L
+ var i = 0
+ while (i < values.length) {
+ totalSize += values(i).getSizeInBytes + 4 + 4
+ i += 1
+ }
+ assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+ // [key size] [values size] [key bytes] [values bytes]
+ out.writeInt(key.getSizeInBytes)
+ out.writeInt(totalSize.toInt)
+ out.write(key.getBytes)
+ i = 0
+ while (i < values.length) {
+ // [num of fields] [num of bytes] [row bytes]
+ // write the integer in native order, so they can be read by UNSAFE.getInt()
+ if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+ out.writeInt(values(i).numFields())
+ out.writeInt(values(i).getSizeInBytes)
+ } else {
+ out.writeInt(Integer.reverseBytes(values(i).numFields()))
+ out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+ }
+ out.write(values(i).getBytes)
+ i += 1
+ }
+ }
}
override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ val nKeys = in.readInt()
+ // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
+ val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+
+ val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ .getSizeAsBytes("spark.buffer.pageSize", "64m")
+
+ binaryMap = new BytesToBytesMap(
+ memoryManager,
+ nKeys * 2, // reduce hash collision
+ pageSizeBytes)
+
+ var i = 0
+ var keyBuffer = new Array[Byte](1024)
+ var valuesBuffer = new Array[Byte](1024)
+ while (i < nKeys) {
+ val keySize = in.readInt()
+ val valuesSize = in.readInt()
+ if (keySize > keyBuffer.size) {
+ keyBuffer = new Array[Byte](keySize)
+ }
+ in.readFully(keyBuffer, 0, keySize)
+ if (valuesSize > valuesBuffer.size) {
+ valuesBuffer = new Array[Byte](valuesSize)
+ }
+ in.readFully(valuesBuffer, 0, valuesSize)
+
+ // put it into binary map
+ val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
+ assert(!loc.isDefined, "Duplicated key found!")
+ loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+ valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+ i += 1
+ }
}
}
@@ -181,33 +298,14 @@ private[joins] object UnsafeHashedRelation {
def apply(
input: Iterator[InternalRow],
- buildKeys: Seq[Expression],
- buildPlan: SparkPlan,
- sizeEstimate: Int = 64): HashedRelation = {
- val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output))
- apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
- }
-
- // Used for tests
- def apply(
- input: Iterator[InternalRow],
- buildKeys: Seq[Expression],
- rowSchema: StructType,
+ keyGenerator: UnsafeProjection,
sizeEstimate: Int): HashedRelation = {
- // TODO: Use BytesToBytesMap.
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
- val toUnsafe = UnsafeProjection.create(rowSchema)
- val keyGenerator = UnsafeProjection.create(buildKeys)
// Create a mapping of buildKeys -> rows
while (input.hasNext) {
- val currentRow = input.next()
- val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
- currentRow.asInstanceOf[UnsafeRow]
- } else {
- toUnsafe(currentRow)
- }
+ val unsafeRow = input.next().asInstanceOf[UnsafeRow]
val rowKey = keyGenerator(unsafeRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
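The serialization format described in the UnsafeHashedRelation comment above can be illustrated with plain Java streams. This is only a sketch of the layout, with raw byte arrays standing in for UnsafeRow keys and values; as in writeExternal above, the per-value ints are written in native byte order so the deserialized page can later be walked with Unsafe.getInt.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.ByteOrder

// Sketch of the on-wire layout: [number of keys], then per key
// [key size][total values size][key bytes][for each value: numFields, numBytes, row bytes].
def writeRelation(entries: Seq[(Array[Byte], Seq[(Int, Array[Byte])])]): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val out = new DataOutputStream(bos)
  // Per-value ints go out in native order so they can later be read with Unsafe.getInt.
  def writeNativeInt(v: Int): Unit =
    if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) out.writeInt(v)
    else out.writeInt(Integer.reverseBytes(v))

  out.writeInt(entries.size)                                  // [number of keys]
  for ((keyBytes, values) <- entries) {
    val valuesSize = values.map { case (_, row) => 4 + 4 + row.length }.sum
    out.writeInt(keyBytes.length)                             // [size of key]
    out.writeInt(valuesSize)                                  // [size of all values in bytes]
    out.write(keyBytes)                                       // [key bytes]
    for ((numFields, rowBytes) <- values) {
      writeNativeInt(numFields)                               // [number of fields]
      writeNativeInt(rowBytes.length)                         // [number of bytes]
      out.write(rowBytes)                                     // [underlying bytes of UnsafeRow]
    }
  }
  out.flush()
  bos.toByteArray
}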
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 874712a4e739f..26a664104d6fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -46,7 +46,7 @@ case class LeftSemiJoinHash(
val hashSet = buildKeyHashSet(buildIter)
hashSemiJoin(streamIter, hashSet)
} else {
- val hashRelation = buildHashRelation(buildIter)
+ val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
hashSemiJoin(streamIter, hashRelation)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 948d0ccebceb0..5439e10a60b2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -45,7 +45,7 @@ case class ShuffledHashJoin(
protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashed = buildHashRelation(buildIter)
+ val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
hashJoin(streamIter, hashed)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index f54f1edd38ec8..d29b593207c4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -50,8 +50,8 @@ case class ShuffledHashOuterJoin(
// TODO this probably can be replaced by external sort (sort merged join?)
joinType match {
case LeftOuter =>
- val hashed = buildHashRelation(rightIter)
- val keyGenerator = streamedKeyGenerator()
+ val hashed = HashedRelation(rightIter, buildKeyGenerator)
+ val keyGenerator = streamedKeyGenerator
leftIter.flatMap( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
@@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin(
})
case RightOuter =>
- val hashed = buildHashRelation(leftIter)
- val keyGenerator = streamedKeyGenerator()
+ val hashed = HashedRelation(leftIter, buildKeyGenerator)
+ val keyGenerator = streamedKeyGenerator
rightIter.flatMap ( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index ec084a299649e..ef1c6e57dc08a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -134,8 +134,19 @@ object EvaluatePython {
}
new GenericInternalRowWithSchema(values, struct)
- case (seq: Seq[Any], array: ArrayType) =>
- seq.map(x => toJava(x, array.elementType)).asJava
+ case (a: ArrayData, array: ArrayType) =>
+ val length = a.numElements()
+ val values = new java.util.ArrayList[Any](length)
+ var i = 0
+ while (i < length) {
+ if (a.isNullAt(i)) {
+ values.add(null)
+ } else {
+ values.add(toJava(a.get(i), array.elementType))
+ }
+ i += 1
+ }
+ values
case (obj: Map[_, _], mt: MapType) => obj.map {
case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
@@ -190,10 +201,10 @@ object EvaluatePython {
case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
case (c: java.util.List[_], ArrayType(elementType, _)) =>
- c.map { e => fromJava(e, elementType)}.toSeq
+ new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray)
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq
+ new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
@@ -267,7 +278,6 @@ object EvaluatePython {
pickler.save(row.values(i))
i += 1
}
- row.values.foreach(pickler.save)
out.write(Opcodes.TUPLE)
out.write(Opcodes.REDUCE)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
new file mode 100644
index 0000000000000..6d903ab23c57f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines various sort operators.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+/**
+ * Performs a sort on-heap.
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ */
+case class Sort(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan)
+ extends UnaryNode {
+ override def requiredChildDistribution: Seq[Distribution] =
+ if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
+ child.execute().mapPartitions( { iterator =>
+ val ordering = newOrdering(sortOrder, child.output)
+ iterator.map(_.copy()).toArray.sorted(ordering).iterator
+ }, preservesPartitioning = true)
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+}
+
+/**
+ * Performs a sort, spilling to disk as needed.
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ */
+case class ExternalSort(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
+ child.execute().mapPartitions( { iterator =>
+ val ordering = newOrdering(sortOrder, child.output)
+ val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
+ sorter.insertAll(iterator.map(r => (r.copy(), null)))
+ val baseIterator = sorter.iterator.map(_._1)
+ // TODO(marmbrus): The complex type signature below thwarts inference for no reason.
+ CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
+ }, preservesPartitioning = true)
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+}
+
+/**
+ * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of
+ * Project Tungsten).
+ *
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set to a
+ * positive value, the sorter will spill every `testSpillFrequency` records.
+ */
+case class TungstenSort(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan,
+ testSpillFrequency: Int = 0)
+ extends UnaryNode {
+
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val schema = child.schema
+ val childOutput = child.output
+ child.execute().mapPartitions({ iter =>
+ val ordering = newOrdering(sortOrder, childOutput)
+
+ // The comparator for comparing prefix
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
+ val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+ // The generator for prefix
+ val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
+ }
+ }
+
+ val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+ if (testSpillFrequency > 0) {
+ sorter.setTestSpillFrequency(testSpillFrequency)
+ }
+ sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+ }, preservesPartitioning = true)
+ }
+
+}
+
+object TungstenSort {
+ /**
+ * Returns true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise.
+ */
+ def supportsSchema(schema: StructType): Boolean = {
+ UnsafeExternalRowSorter.supportsSchema(schema)
+ }
+}
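TungstenSort above feeds UnsafeExternalRowSorter a 64-bit prefix per row so that most comparisons can be decided without touching the full row. The following sketch shows the idea with ordinary Scala collections; the prefix function and ordering here are hypothetical stand-ins, not the actual sorter API.

// Sort rows by a cheap 64-bit prefix first; fall back to the full ordering only on ties.
def sortWithPrefix[T](rows: Seq[T], prefix: T => Long, full: Ordering[T]): Seq[T] = {
  rows
    .map(r => (prefix(r), r))
    .sortWith { case ((p1, r1), (p2, r2)) =>
      if (p1 != p2) p1 < p2        // cheap comparison decides most pairs
      else full.lt(r1, r2)         // ties fall back to comparing the full rows
    }
    .map(_._2)
}

// Example: sort (name, age) pairs by name, using the first character as the prefix.
val people = Seq(("carol", 41), ("alice", 25), ("bob", 30))
val sorted = sortWithPrefix[(String, Int)](
  people,
  p => if (p._1.isEmpty) 0L else p._1.charAt(0).toLong,
  Ordering.by(_._1))
// sorted == Seq(("alice", 25), ("bob", 30), ("carol", 41))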
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 78da2840dad69..9329148aa233c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame}
private[sql] object FrequentItems extends Logging {
@@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging {
baseCounts
}
)
- val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
+ val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_))
val resultRow = InternalRow(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index cab3db609dd4b..46dc4605a5ccb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -634,7 +634,7 @@ object functions {
* @group normal_funcs
* @since 1.4.0
*/
- def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID()
+ def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID()
/**
* Return an alternative value `r` if `l` is NaN.
@@ -741,7 +741,16 @@ object functions {
* @group normal_funcs
* @since 1.4.0
*/
- def sparkPartitionId(): Column = execution.expressions.SparkPartitionID
+ def sparkPartitionId(): Column = SparkPartitionID()
+
+ /**
+ * The file name of the current Spark task
+ *
+ * Note that this is non-deterministic because it depends on what is currently being read in.
+ *
+ * @group normal_funcs
+ */
+ def inputFileName(): Column = InputFileName()
/**
* Computes the square root of the specified float value.
@@ -1423,7 +1432,8 @@ object functions {
def round(columnName: String): Column = round(Column(columnName), 0)
/**
- * Returns the value of `e` rounded to `scale` decimal places.
+ * Returns the value of `e` rounded to `scale` decimal places when `scale` >= 0,
+ * or rounded at the integral part when `scale` < 0.
*
* @group math_funcs
* @since 1.5.0
@@ -1431,7 +1441,8 @@ object functions {
def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
/**
- * Returns the value of the given column rounded to `scale` decimal places.
+ * Returns the value of the given column rounded to `scale` decimal places when `scale` >= 0,
+ * or rounded at the integral part when `scale` < 0.
*
* @group math_funcs
* @since 1.5.0
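A quick illustration of the extended semantics, assuming any DataFrame df and import org.apache.spark.sql.functions._ (the commented values are the expected results, not asserted by this patch):

df.select(
  round(lit(1234.567), 2),    // 1234.57
  round(lit(1234.567), 0),    // 1235.0
  round(lit(1234.567), -2)    // 1200.0  (negative scale rounds at the integral part)
)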
@@ -1916,6 +1927,14 @@ object functions {
// DateTime functions
//////////////////////////////////////////////////////////////////////////////////////////////
+ /**
+ * Returns the date that is numMonths after startDate.
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def add_months(startDate: Column, numMonths: Int): Column =
+ AddMonths(startDate.expr, Literal(numMonths))
+
/**
* Converts a date/timestamp/string to a value of string in the format specified by the date
* format given by the second argument.
@@ -1948,6 +1967,20 @@ object functions {
def date_format(dateColumnName: String, format: String): Column =
date_format(Column(dateColumnName), format)
+ /**
+ * Returns the date that is `days` days after `start`
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days))
+
+ /**
+ * Returns the date that is `days` days before `start`
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days))
+
/**
* Extracts the year as an integer from a given date/timestamp/string.
* @group datetime_funcs
@@ -2032,6 +2065,16 @@ object functions {
*/
def hour(columnName: String): Column = hour(Column(columnName))
+ /**
+ * Given a date column, returns the last day of the month that the given date belongs to.
+ * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of
+ * July 2015.
+ *
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def last_day(e: Column): Column = LastDay(e.expr)
+
/**
* Extracts the minutes as an integer from a given date/timestamp/string.
* @group datetime_funcs
@@ -2046,6 +2089,28 @@ object functions {
*/
def minute(columnName: String): Column = minute(Column(columnName))
+ /**
+ * Returns the number of months between dates `date1` and `date2`.
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr)
+
+ /**
+ * Given a date column, returns the first date which is later than the value of the date column
+ * that is on the specified day of the week.
+ *
+ * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first
+ * Sunday after 2015-07-27.
+ *
+ * Day of the week parameter is case insensitive, and accepts:
+ * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
+ *
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr)
+
/**
* Extracts the seconds as an integer from a given date/timestamp/string.
* @group datetime_funcs
@@ -2074,6 +2139,64 @@ object functions {
*/
def weekofyear(columnName: String): Column = weekofyear(Column(columnName))
+ /**
+ * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+ * representing the timestamp of that moment in the current system time zone in the
+ * default format yyyy-MM-dd HH:mm:ss.
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+
+ /**
+ * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+ * representing the timestamp of that moment in the current system time zone in the given
+ * format.
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f))
+
+ /**
+ * Gets current Unix timestamp in seconds.
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss"))
+
+ /**
+ * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds),
+ * using the default timezone and the default locale. Returns null on failure.
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+
+ /**
+ * Converts a time string with the given pattern
+ * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html])
+ * to a Unix timestamp (in seconds). Returns null on failure.
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p))
+
+ /**
+ * Converts the column into DateType.
+ *
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def to_date(e: Column): Column = ToDate(e.expr)
+
+ /**
+ * Returns date truncated to the unit specified by the format.
+ *
+ * @group datetime_funcs
+ * @since 1.5.0
+ */
+ def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format))
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Collection functions
//////////////////////////////////////////////////////////////////////////////////////////////
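A usage sketch for the datetime helpers introduced above, assuming a SQLContext named sqlContext and a one-column DataFrame of date strings (the column name and sample values are illustrative; the commented results follow the Scaladoc above):

import org.apache.spark.sql.functions._
import sqlContext.implicits._

val df = Seq("2015-07-27").toDF("d").select(to_date($"d").as("d"))
df.select(
  date_add($"d", 3),          // 2015-07-30
  date_sub($"d", 3),          // 2015-07-24
  add_months($"d", 1),        // 2015-08-27
  last_day($"d"),             // 2015-07-31
  next_day($"d", "Sunday"),   // 2015-08-02
  trunc($"d", "month")        // 2015-07-01
).show()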
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 0eb3b04007f8d..04ab5e2217882 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -125,7 +125,7 @@ private[sql] object InferSchema {
* Convert NullType to StringType and remove StructTypes with no fields
*/
private def canonicalizeType: DataType => Option[DataType] = {
- case at@ArrayType(elementType, _) =>
+ case at @ ArrayType(elementType, _) =>
for {
canonicalType <- canonicalizeType(elementType)
} yield {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 381e7ed54428f..1c309f8794ef3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -110,8 +110,13 @@ private[sql] object JacksonParser {
case (START_OBJECT, st: StructType) =>
convertObject(factory, parser, st)
+ case (START_ARRAY, st: StructType) =>
+ // SPARK-3308: support reading top-level JSON arrays, treating every element
+ // of such an array as a row
+ convertArray(factory, parser, st)
+
case (START_ARRAY, ArrayType(st, _)) =>
- convertList(factory, parser, st)
+ convertArray(factory, parser, st)
case (START_OBJECT, ArrayType(st, _)) =>
// the business end of SPARK-3308:
@@ -165,16 +170,16 @@ private[sql] object JacksonParser {
builder.result()
}
- private def convertList(
+ private def convertArray(
factory: JsonFactory,
parser: JsonParser,
- schema: DataType): Seq[Any] = {
- val builder = Seq.newBuilder[Any]
+ elementType: DataType): ArrayData = {
+ val values = scala.collection.mutable.ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
- builder += convertField(factory, parser, schema)
+ values += convertField(factory, parser, elementType)
}
- builder.result()
+ new GenericArrayData(values.toArray)
}
private def parseJson(
@@ -201,12 +206,15 @@ private[sql] object JacksonParser {
val parser = factory.createParser(record)
parser.nextToken()
- // to support both object and arrays (see SPARK-3308) we'll start
- // by converting the StructType schema to an ArrayType and let
- // convertField wrap an object into a single value array when necessary.
- convertField(factory, parser, ArrayType(schema)) match {
+ convertField(factory, parser, schema) match {
case null => failedRecord(record)
- case list: Seq[InternalRow @unchecked] => list
+ case row: InternalRow => row :: Nil
+ case array: ArrayData =>
+ if (array.numElements() == 0) {
+ Nil
+ } else {
+ array.toArray().map(_.asInstanceOf[InternalRow])
+ }
case _ =>
sys.error(
s"Failed to parse record $record. Please make sure that each line of the file " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index e00bd90edb3dd..172db8362afb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -325,7 +325,7 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = elementConverter
- override def end(): Unit = updater.set(currentArray)
+ override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray))
// NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the
// next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index ea51650fe9039..2332a36468dbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.parquet
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.ArrayData
// TODO Removes this while fixing SPARK-8848
private[sql] object CatalystConverter {
@@ -32,7 +33,7 @@ private[sql] object CatalystConverter {
val MAP_SCHEMA_NAME = "map"
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
- type ArrayScalaType[T] = Seq[T]
+ type ArrayScalaType[T] = ArrayData
type StructScalaType[T] = InternalRow
type MapScalaType[K, V] = Map[K, V]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index cc6fa2b88663f..b4337a48dbd80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -39,11 +39,10 @@ import org.apache.parquet.{Log => ParquetLog}
import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD}
import org.apache.spark.rdd.RDD._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD}
import org.apache.spark.sql.execution.datasources.PartitionSpec
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -125,6 +124,9 @@ private[sql] class ParquetRelation(
.map(_.toBoolean)
.getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED))
+ private val mergeRespectSummaries =
+ sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES)
+
private val maybeMetastoreSchema = parameters
.get(ParquetRelation.METASTORE_SCHEMA)
.map(DataType.fromJson(_).asInstanceOf[StructType])
@@ -422,7 +424,21 @@ private[sql] class ParquetRelation(
val filesToTouch =
if (shouldMergeSchemas) {
// Also includes summary files, 'cause there might be empty partition directories.
- (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq
+
+ // If the mergeRespectSummaries config is true, we assume that the schemas of all
+ // part-files are consistent with the summary files, so we skip the part-files and
+ // only merge the schemas contained in the summary files.
+ // If the config is disabled, which is the default setting, we merge all part-files.
+ // Enable this configuration only if you are sure that the Parquet part-files being
+ // read have corresponding summary files containing the correct schema.
+
+ val needMerged: Seq[FileStatus] =
+ if (mergeRespectSummaries) {
+ Seq()
+ } else {
+ dataStatuses
+ }
+ (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq
} else {
// Tries any "_common_metadata" first. Parquet files written by old versions or Parquet
// don't have this.
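A configuration sketch for the respect-summaries behaviour described above. The exact key string lives behind SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES; the spelling used below is an assumption and should be checked against SQLConf before relying on it.

// Assumed key for SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES; verify against SQLConf.
sqlContext.setConf("spark.sql.parquet.respectSummaryFiles", "true")

// With the flag on, schema merging reads only the _metadata / _common_metadata summary
// files and skips the individual part-files.
val merged = sqlContext.read.option("mergeSchema", "true").parquet("/path/to/table")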
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 78ecfad1d57c6..ec8da38a3d427 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -146,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
array: CatalystConverter.ArrayScalaType[_]): Unit = {
val elementType = schema.elementType
writer.startGroup()
- if (array.size > 0) {
+ if (array.numElements() > 0) {
if (schema.containsNull) {
writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0)
var i = 0
- while (i < array.size) {
+ while (i < array.numElements()) {
writer.startGroup()
- if (array(i) != null) {
+ if (!array.isNullAt(i)) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
- writeValue(elementType, array(i))
+ writeValue(elementType, array.get(i))
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
@@ -164,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
} else {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
- while (i < array.size) {
- writeValue(elementType, array(i))
+ while (i < array.numElements()) {
+ writeValue(elementType, array.get(i))
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
@@ -293,8 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
case BinaryType =>
writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
- case DecimalType.Fixed(precision, _) =>
- writeDecimal(record.getDecimal(index), precision)
+ case DecimalType.Fixed(precision, scale) =>
+ writeDecimal(record.getDecimal(index, precision, scale), precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 72c42f4fe376b..2c669bb59a0b5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -30,7 +30,6 @@
import scala.collection.JavaConversions;
import scala.collection.Seq;
-import scala.collection.mutable.Buffer;
import java.io.Serializable;
import java.util.Arrays;
@@ -168,10 +167,10 @@ public void testCreateDataFrameFromJavaBeans() {
for (int i = 0; i < result.length(); i++) {
Assert.assertEquals(bean.getB()[i], result.apply(i));
}
- Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello");
+ Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello");
Assert.assertArrayEquals(
bean.getC().get("hello"),
- Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer)));
+ Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer)));
Seq d = first.getAs(3);
Assert.assertEquals(bean.getD().size(), d.length());
for (int i = 0; i < d.length(); i++) {
@@ -227,4 +226,13 @@ public void testCovariance() {
Double result = df.stat().cov("a", "b");
Assert.assertTrue(Math.abs(result) < 1e-6);
}
+
+ @Test
+ public void testSampleBy() {
+ DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
+ DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+ Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+ Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
+ Assert.assertArrayEquals(expected, actual);
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 1f9f7118c3f04..eb64684ae0fd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,16 +19,19 @@ package org.apache.spark.sql
import org.scalatest.Matchers._
-import org.apache.spark.sql.execution.Project
+import org.apache.spark.sql.execution.{Project, TungstenProject}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
+import org.apache.spark.sql.test.SQLTestUtils
-class ColumnExpressionSuite extends QueryTest {
+class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
import org.apache.spark.sql.TestData._
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
+ override def sqlContext(): SQLContext = ctx
+
test("alias") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
assert(df.select(df("a").as("b")).columns.head === "b")
@@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest {
)
}
+ test("InputFileName") {
+ withTempPath { dir =>
+ val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id")
+ data.write.parquet(dir.getCanonicalPath)
+ val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName())
+ .head.getString(0)
+ assert(answer.contains(dir.getCanonicalPath))
+
+ checkAnswer(data.select(inputFileName()).limit(1), Row(""))
+ }
+ }
+
test("lift alias out of cast") {
compareExpressions(
col("1234").as("name").cast("int").expr,
@@ -523,6 +538,7 @@ class ColumnExpressionSuite extends QueryTest {
def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = {
val projects = df.queryExecution.executedPlan.collect {
case project: Project => project
+ case tungstenProject: TungstenProject => tungstenProject
}
assert(projects.size === expectedNumProjects)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b26d3ab253a1d..228ece8065151 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{BinaryType, DecimalType}
class DataFrameAggregateSuite extends QueryTest {
@@ -191,4 +191,13 @@ class DataFrameAggregateSuite extends QueryTest {
Row(null))
}
+ test("aggregation can't work on binary type") {
+ val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
+ intercept[AnalysisException] {
+ df.groupBy("c").agg(count("*"))
+ }
+ intercept[AnalysisException] {
+ df.distinct
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 7ba4ba73e0cc9..07a675e64f527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -21,9 +21,9 @@ import java.util.Random
import org.scalatest.Matchers._
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col
-class DataFrameStatSuite extends SparkFunSuite {
+class DataFrameStatSuite extends QueryTest {
private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
import sqlCtx.implicits._
@@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite {
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}
+
+ test("sampleBy") {
+ val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+ val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+ checkAnswer(
+ sampled.groupBy("key").count().orderBy("key"),
+ Seq(Row(0, 5), Row(1, 8)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f67f2c60c0e16..97beae2f85c50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -23,40 +23,38 @@ import scala.language.postfixOps
import scala.util.Random
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
+import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.json.JSONRelation
+import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils}
class DataFrameSuite extends QueryTest with SQLTestUtils {
import org.apache.spark.sql.TestData._
- lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
-
- def sqlContext: SQLContext = ctx
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+ import sqlContext.implicits._
test("analysis error should be eagerly reported") {
- val oldSetting = ctx.conf.dataFrameEagerAnalysis
// Eager analysis.
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true)
-
- intercept[Exception] { testData.select('nonExistentName) }
- intercept[Exception] {
- testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
- }
- intercept[Exception] {
- testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
- }
- intercept[Exception] {
- testData.groupBy($"abcd").agg(Map("key" -> "sum"))
+ withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") {
+ intercept[Exception] { testData.select('nonExistentName) }
+ intercept[Exception] {
+ testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
+ }
+ intercept[Exception] {
+ testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
+ }
+ intercept[Exception] {
+ testData.groupBy($"abcd").agg(Map("key" -> "sum"))
+ }
}
// No more eager analysis once the flag is turned off
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
- testData.select('nonExistentName)
-
- // Set the flag back to original value before this test.
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting)
+ withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") {
+ testData.select('nonExistentName)
+ }
}
test("dataframe toString") {
@@ -74,21 +72,18 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
}
test("invalid plan toString, debug mode") {
- val oldSetting = ctx.conf.dataFrameEagerAnalysis
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true)
-
// Turn on debug mode so we can see invalid query plans.
import org.apache.spark.sql.execution.debug._
- ctx.debug()
- val badPlan = testData.select('badColumn)
+ withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") {
+ sqlContext.debug()
- assert(badPlan.toString contains badPlan.queryExecution.toString,
- "toString on bad query plans should include the query execution but was:\n" +
- badPlan.toString)
+ val badPlan = testData.select('badColumn)
- // Set the flag back to original value before this test.
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting)
+ assert(badPlan.toString contains badPlan.queryExecution.toString,
+ "toString on bad query plans should include the query execution but was:\n" +
+ badPlan.toString)
+ }
}
test("access complex data") {
@@ -104,8 +99,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
}
test("empty data frame") {
- assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String])
- assert(ctx.emptyDataFrame.count() === 0)
+ assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
+ assert(sqlContext.emptyDataFrame.count() === 0)
}
test("head and take") {
@@ -341,7 +336,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
}
test("replace column using withColumn") {
- val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
+ val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)
checkAnswer(
df3.select("x"),
@@ -422,7 +417,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
test("randomSplit") {
val n = 600
- val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
@@ -491,6 +486,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
}
+ test("inputFiles") {
+ val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"),
+ Some(testData.schema), None, Map.empty)(sqlContext)
+ val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1))
+ assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet)
+
+ val fakeRelation2 = new JSONRelation("/json/path", 1, Some(testData.schema), sqlContext)
+ val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2))
+ assert(df2.inputFiles.toSet == fakeRelation2.path.toSet)
+
+ val unionDF = df1.unionAll(df2)
+ assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path)
+
+ val filtered = df1.filter("false").unionAll(df2.intersect(df2))
+ assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path)
+ }
+
ignore("show") {
// This test case is intended ignored, but to make sure it compiles correctly
testData.select($"*").show()
@@ -499,7 +511,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
test("showString: truncate = [true, false]") {
val longString = Array.fill(21)("1").mkString
- val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF()
+ val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = """+---------------------+
||_1 |
|+---------------------+
@@ -589,21 +601,17 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
}
test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
- val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+ val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
- val df = ctx.createDataFrame(rowRDD, schema)
+ val df = sqlContext.createDataFrame(rowRDD, schema)
df.rdd.collect()
}
- test("SPARK-6899") {
- val originalValue = ctx.conf.codegenEnabled
- ctx.setConf(SQLConf.CODEGEN_ENABLED, true)
- try{
+ test("SPARK-6899: type should match when using codegen") {
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
checkAnswer(
decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
- } finally {
- ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
}
}
@@ -615,14 +623,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
}
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
- val df = ctx.read.json(ctx.sparkContext.makeRDD(
+ val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df.select(df("`a.b`.c.`d..e`.`f`")),
Row(1)
)
- val df2 = ctx.read.json(ctx.sparkContext.makeRDD(
+ val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df2.select(df2("`a b`.c.d e.f")),
@@ -642,7 +650,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
}
test("SPARK-7324 dropDuplicates") {
- val testData = ctx.sparkContext.parallelize(
+ val testData = sqlContext.sparkContext.parallelize(
(2, 1, 2) :: (1, 1, 1) ::
(1, 2, 1) :: (2, 1, 2) ::
(2, 2, 2) :: (2, 2, 1) ::
@@ -690,49 +698,49 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
test("SPARK-7150 range api") {
// numSlice is greater than length
- val res1 = ctx.range(0, 10, 1, 15).select("id")
+ val res1 = sqlContext.range(0, 10, 1, 15).select("id")
assert(res1.count == 10)
assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
- val res2 = ctx.range(3, 15, 3, 2).select("id")
+ val res2 = sqlContext.range(3, 15, 3, 2).select("id")
assert(res2.count == 4)
assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
- val res3 = ctx.range(1, -2).select("id")
+ val res3 = sqlContext.range(1, -2).select("id")
assert(res3.count == 0)
// start is positive, end is negative, step is negative
- val res4 = ctx.range(1, -2, -2, 6).select("id")
+ val res4 = sqlContext.range(1, -2, -2, 6).select("id")
assert(res4.count == 2)
assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
// start, end, step are negative
- val res5 = ctx.range(-3, -8, -2, 1).select("id")
+ val res5 = sqlContext.range(-3, -8, -2, 1).select("id")
assert(res5.count == 3)
assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
// start, end are negative, step is positive
- val res6 = ctx.range(-8, -4, 2, 1).select("id")
+ val res6 = sqlContext.range(-8, -4, 2, 1).select("id")
assert(res6.count == 2)
assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
- val res7 = ctx.range(-10, -9, -20, 1).select("id")
+ val res7 = sqlContext.range(-10, -9, -20, 1).select("id")
assert(res7.count == 0)
- val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+ val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
assert(res8.count == 3)
assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
- val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+ val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
assert(res9.count == 2)
assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
// only end provided as argument
- val res10 = ctx.range(10).select("id")
+ val res10 = sqlContext.range(10).select("id")
assert(res10.count == 10)
assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
- val res11 = ctx.range(-1).select("id")
+ val res11 = sqlContext.range(-1).select("id")
assert(res11.count == 0)
}
@@ -799,13 +807,13 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
// pass case: parquet table (HadoopFsRelation)
df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath)
- val pdf = ctx.read.parquet(tempParquetFile.getCanonicalPath)
+ val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath)
pdf.registerTempTable("parquet_base")
insertion.write.insertInto("parquet_base")
// pass case: json table (InsertableRelation)
df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath)
- val jdf = ctx.read.json(tempJsonFile.getCanonicalPath)
+ val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath)
jdf.registerTempTable("json_base")
insertion.write.mode(SaveMode.Overwrite).insertInto("json_base")
@@ -825,11 +833,54 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
// error case: insert into an OneRowRelation
- new DataFrame(ctx, OneRowRelation).registerTempTable("one_row")
+ new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row")
val e3 = intercept[AnalysisException] {
insertion.write.insertInto("one_row")
}
assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed."))
}
}
+
+ test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") {
+ // Make sure we can pass this test for both codegen mode and interpreted mode.
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+ val df = testData.select(rand(33))
+ assert(df.showString(5) == df.showString(5))
+ }
+
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+ val df = testData.select(rand(33))
+ assert(df.showString(5) == df.showString(5))
+ }
+
+ // We will reuse the same Expression object for LocalRelation.
+ val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33))
+ assert(df.showString(5) == df.showString(5))
+ }
+
+ test("SPARK-8609: local DataFrame with random columns should return same value after sort") {
+ // Make sure we can pass this test for both codegen mode and interpreted mode.
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+ checkAnswer(testData.sort(rand(33)), testData.sort(rand(33)))
+ }
+
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+ checkAnswer(testData.sort(rand(33)), testData.sort(rand(33)))
+ }
+
+ // We will reuse the same Expression object for LocalRelation.
+ val df = (1 to 10).map(Tuple1.apply).toDF()
+ checkAnswer(df.sort(rand(33)), df.sort(rand(33)))
+ }
+
+ test("SPARK-9083: sort with non-deterministic expressions") {
+ import org.apache.spark.util.random.XORShiftRandom
+
+ val seed = 33
+ val df = (1 to 100).map(Tuple1.apply).toDF("i")
+ val random = new XORShiftRandom(seed)
+ val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1)
+ val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
+ assert(expected === actual)
+ }
}
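
Side note on the withSQLConf calls that the rewritten tests above now rely on: the helper comes from SQLTestUtils and simply wraps the old set/run/restore bookkeeping. The following is only a minimal sketch of that pattern, under the assumption that the helper snapshots the touched keys and restores them afterwards; the name WithSQLConfSketch and the handling of previously-unset keys are illustrative, not the actual SQLTestUtils code.

import org.apache.spark.sql.SQLContext

object WithSQLConfSketch {
  // Set the given SQLConf keys, run the body, then restore whatever was there before.
  def withSQLConf(ctx: SQLContext)(pairs: (String, String)*)(body: => Unit): Unit = {
    val (keys, values) = pairs.unzip
    val previous = keys.map(k => Option(ctx.getConf(k, null)))      // snapshot current values
    keys.zip(values).foreach { case (k, v) => ctx.setConf(k, v) }
    try body finally {
      keys.zip(previous).foreach {
        case (k, Some(old)) => ctx.setConf(k, old)                  // restore the old value
        case (_, None)      => ()                                   // key was unset before; leave it
      }
    }
  }
}

Usage mirrors the tests above, e.g. WithSQLConfSketch.withSQLConf(sqlContext)(SQLConf.CODEGEN_ENABLED.key -> "true") { ... }.
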
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
new file mode 100644
index 0000000000000..bf8ef9a97bc60
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+
+/**
+ * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode.
+ *
+ * This is here for now so I can make sure the Tungsten project is tested without refactoring the
+ * existing end-to-end test infra. In the long run this should just go away.
+ */
+class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
+
+ override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
+ import sqlContext.implicits._
+
+ test("test simple types") {
+ withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+ val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
+ assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
+ }
+ }
+
+ test("test struct type") {
+ withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+ val struct = Row(1, 2L, 3.0F, 3.0)
+ val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct)))
+
+ val schema = new StructType()
+ .add("a", IntegerType)
+ .add("b",
+ new StructType()
+ .add("b1", IntegerType)
+ .add("b2", LongType)
+ .add("b3", FloatType)
+ .add("b4", DoubleType))
+
+ val df = sqlContext.createDataFrame(data, schema)
+ assert(df.select("b").first() === Row(struct))
+ }
+ }
+
+ test("test nested struct type") {
+ withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
+ val innerStruct = Row(1, "abcd")
+ val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
+ val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct)))
+
+ val schema = new StructType()
+ .add("a", IntegerType)
+ .add("b",
+ new StructType()
+ .add("b1", IntegerType)
+ .add("b2", LongType)
+ .add("b3", FloatType)
+ .add("b4", DoubleType)
+ .add("b5", new StructType()
+ .add("b5a", IntegerType)
+ .add("b5b", StringType))
+ .add("b6", StringType))
+
+ val df = sqlContext.createDataFrame(data, schema)
+ assert(df.select("b").first() === Row(outerStruct))
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index 9e80ae86920d9..8c596fad74ee4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -20,13 +20,36 @@ package org.apache.spark.sql
import java.sql.{Timestamp, Date}
import java.text.SimpleDateFormat
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions._
+import org.apache.spark.unsafe.types.CalendarInterval
class DateFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
+ test("function current_date") {
+ val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
+ val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+ val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0))
+ val d2 = DateTimeUtils.fromJavaDate(
+ ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
+ val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+ assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
+ }
+
+ // This is a bad test. SPARK-9196 will fix it and re-enable it.
+ ignore("function current_timestamp") {
+ val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
+ checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
+ // Execution in one query should return the same value
+ checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
+ Row(true))
+ assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
+ 0).getTime - System.currentTimeMillis()) < 5000)
+ }
+
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
val sdfDate = new SimpleDateFormat("yyyy-MM-dd")
val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
@@ -184,4 +207,242 @@ class DateFunctionsSuite extends QueryTest {
Row(15, 15, 15))
}
+ test("function date_add") {
+ val st1 = "2015-06-01 12:34:56"
+ val st2 = "2015-06-02 12:34:56"
+ val t1 = Timestamp.valueOf(st1)
+ val t2 = Timestamp.valueOf(st2)
+ val s1 = "2015-06-01"
+ val s2 = "2015-06-02"
+ val d1 = Date.valueOf(s1)
+ val d2 = Date.valueOf(s2)
+ val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss")
+ checkAnswer(
+ df.select(date_add(col("d"), 1)),
+ Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03"))))
+ checkAnswer(
+ df.select(date_add(col("t"), 3)),
+ Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05"))))
+ checkAnswer(
+ df.select(date_add(col("s"), 5)),
+ Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07"))))
+ checkAnswer(
+ df.select(date_add(col("ss"), 7)),
+ Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09"))))
+
+ checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null)))
+ checkAnswer(
+ df.selectExpr("""DATE_ADD(d, 1)"""),
+ Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03"))))
+ }
+
+ test("function date_sub") {
+ val st1 = "2015-06-01 12:34:56"
+ val st2 = "2015-06-02 12:34:56"
+ val t1 = Timestamp.valueOf(st1)
+ val t2 = Timestamp.valueOf(st2)
+ val s1 = "2015-06-01"
+ val s2 = "2015-06-02"
+ val d1 = Date.valueOf(s1)
+ val d2 = Date.valueOf(s2)
+ val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss")
+ checkAnswer(
+ df.select(date_sub(col("d"), 1)),
+ Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01"))))
+ checkAnswer(
+ df.select(date_sub(col("t"), 1)),
+ Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01"))))
+ checkAnswer(
+ df.select(date_sub(col("s"), 1)),
+ Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01"))))
+ checkAnswer(
+ df.select(date_sub(col("ss"), 1)),
+ Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01"))))
+ checkAnswer(
+ df.select(date_sub(lit(null), 1)).limit(1), Row(null))
+
+ checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null)))
+ checkAnswer(
+ df.selectExpr("""DATE_SUB(d, 1)"""),
+ Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01"))))
+ }
+
+ test("time_add") {
+ val t1 = Timestamp.valueOf("2015-07-31 23:59:59")
+ val t2 = Timestamp.valueOf("2015-12-31 00:00:00")
+ val d1 = Date.valueOf("2015-07-31")
+ val d2 = Date.valueOf("2015-12-31")
+ val i = new CalendarInterval(2, 2000000L)
+ val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d")
+ checkAnswer(
+ df.selectExpr(s"d + $i"),
+ Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2016-02-29"))))
+ checkAnswer(
+ df.selectExpr(s"t + $i"),
+ Seq(Row(Timestamp.valueOf("2015-10-01 00:00:01")),
+ Row(Timestamp.valueOf("2016-02-29 00:00:02"))))
+ }
+
+ test("time_sub") {
+ val t1 = Timestamp.valueOf("2015-10-01 00:00:01")
+ val t2 = Timestamp.valueOf("2016-02-29 00:00:02")
+ val d1 = Date.valueOf("2015-09-30")
+ val d2 = Date.valueOf("2016-02-29")
+ val i = new CalendarInterval(2, 2000000L)
+ val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d")
+ checkAnswer(
+ df.selectExpr(s"d - $i"),
+ Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30"))))
+ checkAnswer(
+ df.selectExpr(s"t - $i"),
+ Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")),
+ Row(Timestamp.valueOf("2015-12-31 00:00:00"))))
+ }
+
+ test("function add_months") {
+ val d1 = Date.valueOf("2015-08-31")
+ val d2 = Date.valueOf("2015-02-28")
+ val df = Seq((1, d1), (2, d2)).toDF("n", "d")
+ checkAnswer(
+ df.select(add_months(col("d"), 1)),
+ Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31"))))
+ checkAnswer(
+ df.selectExpr("add_months(d, -1)"),
+ Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31"))))
+ }
+
+ test("function months_between") {
+ val d1 = Date.valueOf("2015-07-31")
+ val d2 = Date.valueOf("2015-02-16")
+ val t1 = Timestamp.valueOf("2014-09-30 23:30:00")
+ val t2 = Timestamp.valueOf("2015-09-16 12:00:00")
+ val s1 = "2014-09-15 11:30:00"
+ val s2 = "2015-10-01 00:00:00"
+ val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s")
+ checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0)))
+ checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5)))
+ }
+
+ test("function last_day") {
+ val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d")
+ val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t")
+ checkAnswer(
+ df1.select(last_day(col("d"))),
+ Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31"))))
+ checkAnswer(
+ df2.select(last_day(col("t"))),
+ Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31"))))
+ }
+
+ test("function next_day") {
+ val df1 = Seq(("mon", "2015-07-23"), ("tuesday", "2015-07-20")).toDF("dow", "d")
+ val df2 = Seq(("th", "2015-07-23 00:11:22"), ("xx", "2015-07-24 11:22:33")).toDF("dow", "t")
+ checkAnswer(
+ df1.select(next_day(col("d"), "MONDAY")),
+ Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27"))))
+ checkAnswer(
+ df2.select(next_day(col("t"), "th")),
+ Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30"))))
+ }
+
+ test("function to_date") {
+ val d1 = Date.valueOf("2015-07-22")
+ val d2 = Date.valueOf("2015-07-01")
+ val t1 = Timestamp.valueOf("2015-07-22 10:00:00")
+ val t2 = Timestamp.valueOf("2014-12-31 23:59:59")
+ val s1 = "2015-07-22 10:00:00"
+ val s2 = "2014-12-31"
+ val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s")
+
+ checkAnswer(
+ df.select(to_date(col("t"))),
+ Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+ checkAnswer(
+ df.select(to_date(col("d"))),
+ Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
+ checkAnswer(
+ df.select(to_date(col("s"))),
+ Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+
+ checkAnswer(
+ df.selectExpr("to_date(t)"),
+ Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+ checkAnswer(
+ df.selectExpr("to_date(d)"),
+ Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
+ checkAnswer(
+ df.selectExpr("to_date(s)"),
+ Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+ }
+
+ test("function trunc") {
+ val df = Seq(
+ (1, Timestamp.valueOf("2015-07-22 10:00:00")),
+ (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t")
+
+ checkAnswer(
+ df.select(trunc(col("t"), "YY")),
+ Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01"))))
+
+ checkAnswer(
+ df.selectExpr("trunc(t, 'Month')"),
+ Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01"))))
+ }
+
+ test("from_unixtime") {
+ val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+ val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
+ val sdf2 = new SimpleDateFormat(fmt2)
+ val fmt3 = "yy-MM-dd HH-mm-ss"
+ val sdf3 = new SimpleDateFormat(fmt3)
+ val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b")
+ checkAnswer(
+ df.select(from_unixtime(col("a"))),
+ Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000)))))
+ checkAnswer(
+ df.select(from_unixtime(col("a"), fmt2)),
+ Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000)))))
+ checkAnswer(
+ df.select(from_unixtime(col("a"), fmt3)),
+ Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000)))))
+ checkAnswer(
+ df.selectExpr("from_unixtime(a)"),
+ Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000)))))
+ checkAnswer(
+ df.selectExpr(s"from_unixtime(a, '$fmt2')"),
+ Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000)))))
+ checkAnswer(
+ df.selectExpr(s"from_unixtime(a, '$fmt3')"),
+ Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000)))))
+ }
+
+ test("unix_timestamp") {
+ val date1 = Date.valueOf("2015-07-24")
+ val date2 = Date.valueOf("2015-07-25")
+ val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3")
+ val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2")
+ val s1 = "2015/07/24 10:00:00.5"
+ val s2 = "2015/07/25 02:02:02.6"
+ val ss1 = "2015-07-24 10:00:00"
+ val ss2 = "2015-07-25 02:02:02"
+ val fmt = "yyyy/MM/dd HH:mm:ss.S"
+ val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss")
+ checkAnswer(df.select(unix_timestamp(col("ts"))), Seq(
+ Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
+ checkAnswer(df.select(unix_timestamp(col("ss"))), Seq(
+ Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
+ checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq(
+ Row(date1.getTime / 1000L), Row(date2.getTime / 1000L)))
+ checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq(
+ Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
+ checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq(
+ Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
+ checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq(
+ Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
+ checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq(
+ Row(date1.getTime / 1000L), Row(date2.getTime / 1000L)))
+ checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq(
+ Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
+ }
+
}
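
On the unix_timestamp / from_unixtime tests above: the two functions are inverse conversions at second precision, and from_unixtime defaults to the "yyyy-MM-dd HH:mm:ss" format. A small usage sketch, assuming it runs inside a QueryTest with the usual TestSQLContext implicits like the suite above:

import java.sql.Timestamp
import org.apache.spark.sql.functions._

// unix_timestamp truncates to whole seconds, so composing it with from_unixtime
// drops the sub-second part of the input timestamp.
val roundTripDf = Seq(Tuple1(Timestamp.valueOf("2015-07-24 10:00:00.3"))).toDF("ts")
checkAnswer(
  roundTripDf.select(from_unixtime(unix_timestamp(col("ts")))),
  Row("2015-07-24 10:00:00"))
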
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala
deleted file mode 100644
index 44b915304533c..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.functions._
-
-class DatetimeExpressionsSuite extends QueryTest {
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-
- import ctx.implicits._
-
- lazy val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
-
- test("function current_date") {
- val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
- val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0))
- val d2 = DateTimeUtils.fromJavaDate(
- ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
- val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
- assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
- }
-
- test("function current_timestamp") {
- checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
- // Execution in one query should return the same value
- checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
- Row(true))
- assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
- 0).getTime - System.currentTimeMillis()) < 5000)
- }
-
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index dfb2a7e099748..27c08f64649ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.types.BinaryType
class JoinSuite extends QueryTest with BeforeAndAfterEach {
@@ -79,9 +80,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
- ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
- ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
- ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[ShuffledHashOuterJoin]),
@@ -489,4 +490,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(3, 2) :: Nil)
}
+
+ test("Join can't work on binary type") {
+ val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
+ val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType)
+ intercept[AnalysisException] {
+ left.join(right, ($"left.N" === $"right.N"), "full")
+ }
+ }
}
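
The expected plan classes above change from ShuffledHashJoin to SortMergeJoin because sort-merge join is now planned by default for equi-joins. A test that still needs hash-join planning can flip the flag around its body, as the Hive compatibility suite further down in this diff does. Minimal sketch only; like the suites in this diff, it assumes the code lives inside the org.apache.spark.sql test packages, since the typed setConf overload is package-private.

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.test.TestSQLContext

// Temporarily disable sort-merge join so equi-joins are planned as hash joins again,
// then restore the default afterwards.
TestSQLContext.setConf(SQLConf.SORTMERGE_JOIN, false)
try {
  // e.g. "SELECT * FROM testData JOIN testData2 ON key = a" would plan a ShuffledHashJoin here
} finally {
  TestSQLContext.setConf(SQLConf.SORTMERGE_JOIN, true)
}
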
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 21256704a5b16..8cf2ef5957d8d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest {
checkAnswer(
ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
- Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
+ Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
+ BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
}
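
About the new expected values for round with a negative scale: "0E3" is a BigDecimal with unscaled value 0 and scale -3, so the result now records the rounding position instead of collapsing to a plain Double. A quick illustration of that distinction, using only java.math.BigDecimal:

import java.math.{BigDecimal => JBigDecimal}

// "0E3" means 0 x 10^3: numerically zero, but the scale of -3 remembers where rounding happened.
val z = new JBigDecimal("0E3")
assert(z.scale == -3)                          // the negative scale from round(pi, -3)
assert(z.compareTo(JBigDecimal.ZERO) == 0)     // numerically equal to zero
assert(!z.equals(JBigDecimal.ZERO))            // equals() also compares scale, so this differs
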
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 358e319476e83..535011fe3db5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -227,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Seq(Row("1"), Row("2")))
}
+ test("SPARK-8828 sum should return null if all input values are null") {
+ withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") {
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+ checkAnswer(
+ sql("select sum(a), avg(a) from allNulls"),
+ Seq(Row(null, null))
+ )
+ }
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+ checkAnswer(
+ sql("select sum(a), avg(a) from allNulls"),
+ Seq(Row(null, null))
+ )
+ }
+ }
+ withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+ checkAnswer(
+ sql("select sum(a), avg(a) from allNulls"),
+ Seq(Row(null, null))
+ )
+ }
+ withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+ checkAnswer(
+ sql("select sum(a), avg(a) from allNulls"),
+ Seq(Row(null, null))
+ )
+ }
+ }
+ }
+
test("aggregation with codegen") {
val originalValue = sqlContext.conf.codegenEnabled
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
@@ -337,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(1))
checkAnswer(
sql("SELECT COALESCE(null, 1, 1.5)"),
- Row(1.toDouble))
+ Row(BigDecimal(1)))
checkAnswer(
sql("SELECT COALESCE(null, null, null)"),
Row(null))
@@ -1203,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("Floating point number format") {
checkAnswer(
- sql("SELECT 0.3"), Row(0.3)
+ sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
)
checkAnswer(
- sql("SELECT -0.8"), Row(-0.8)
+ sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
)
checkAnswer(
- sql("SELECT .5"), Row(0.5)
+ sql("SELECT .5"), Row(BigDecimal(0.5))
)
checkAnswer(
- sql("SELECT -.18"), Row(-0.18)
+ sql("SELECT -.18"), Row(BigDecimal(-0.18))
)
}
@@ -1248,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
)
checkAnswer(
- sql("SELECT -5.2"), Row(-5.2)
+ sql("SELECT -5.2"), Row(BigDecimal(-5.2))
)
checkAnswer(
- sql("SELECT +6.8"), Row(6.8)
+ sql("SELECT +6.8"), Row(BigDecimal(6.8))
)
checkAnswer(
@@ -1546,10 +1577,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("SPARK-8753: add interval type") {
- import org.apache.spark.unsafe.types.Interval
+ import org.apache.spark.unsafe.types.CalendarInterval
val df = sql("select interval 3 years -3 month 7 week 123 microseconds")
- checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 )))
+ checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 )))
withTempPath(f => {
// Currently we don't yet support saving out values of interval data type.
val e = intercept[AnalysisException] {
@@ -1571,20 +1602,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("SPARK-8945: add and subtract expressions for interval type") {
- import org.apache.spark.unsafe.types.Interval
- import org.apache.spark.unsafe.types.Interval.MICROS_PER_WEEK
+ import org.apache.spark.unsafe.types.CalendarInterval
+ import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK
val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
- checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123)))
+ checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123)))
- checkAnswer(df.select(df("i") + new Interval(2, 123)),
- Row(new Interval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123)))
+ checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)),
+ Row(new CalendarInterval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123)))
- checkAnswer(df.select(df("i") - new Interval(2, 123)),
- Row(new Interval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123)))
+ checkAnswer(df.select(df("i") - new CalendarInterval(2, 123)),
+ Row(new CalendarInterval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123)))
// unary minus
checkAnswer(df.select(-df("i")),
- Row(new Interval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
+ Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 0f9c986f649a1..8e0ea76d15881 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -57,19 +57,27 @@ class StringFunctionsSuite extends QueryTest {
}
test("string regex_replace / regex_extract") {
- val df = Seq(("100-200", "")).toDF("a", "b")
+ val df = Seq(
+ ("100-200", "(\\d+)-(\\d+)", "300"),
+ ("100-200", "(\\d+)-(\\d+)", "400"),
+ ("100-200", "(\\d+)", "400")).toDF("a", "b", "c")
checkAnswer(
df.select(
regexp_replace($"a", "(\\d+)", "num"),
regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
- Row("num-num", "100"))
-
- checkAnswer(
- df.selectExpr(
- "regexp_replace(a, '(\\d+)', 'num')",
- "regexp_extract(a, '(\\d+)-(\\d+)', 2)"),
- Row("num-num", "200"))
+ Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil)
+
+ // For testing the mutable state of the expression in codegen.
+ // This is a hacky way to exercise codegen: although codegen is enabled by default,
+ // an interpreted projection is still used when the projection sits directly on a
+ // LocalRelation, hence we add a filter operator.
+ // See the optimizer rule `ConvertToLocalRelation`.
+ checkAnswer(
+ df.filter("isnotnull(a)").selectExpr(
+ "regexp_replace(a, b, c)",
+ "regexp_extract(a, b, 1)"),
+ Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
}
test("string ascii function") {
@@ -290,5 +298,15 @@ class StringFunctionsSuite extends QueryTest {
df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
Row("5.0000"))
}
+
+ // For testing the mutable state of the expression in codegen.
+ // This is a hacky way to exercise codegen: although codegen is enabled by default,
+ // an interpreted projection is still used when the projection sits directly on a
+ // LocalRelation, hence we add a filter operator.
+ // See the optimizer rule `ConvertToLocalRelation`.
+ val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b")
+ checkAnswer(
+ df2.filter("b>0").selectExpr("format_number(a, b)"),
+ Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil)
}
}
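
On the ConvertToLocalRelation remark in the two tests above: a projection sitting directly on a LocalRelation can be folded into the relation by that optimizer rule and evaluated with an interpreted projection, so the generated code for the expression never runs; any cheap filter keeps a real Filter + Project in the physical plan. Illustrative sketch only, with made-up data and the same implicits the suite assumes:

// The first select may be collapsed into the LocalRelation by ConvertToLocalRelation,
// so codegen for regexp_replace is never exercised; the filtered variant keeps a
// physical projection and goes through the generated projection path.
val localDf = Seq(("100-200", "(\\d+)", "num")).toDF("a", "b", "c")
val folded = localDf.selectExpr("regexp_replace(a, b, c)")
val viaCodegen = localDf.filter("isnotnull(a)").selectExpr("regexp_replace(a, b, c)")
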
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 207d7a352c7b3..e340f54850bcc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql
-import java.sql.Timestamp
-
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index c1516b450cbd4..183dc3407b3ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,14 +17,17 @@
package org.apache.spark.sql
+import org.apache.spark.sql.test.SQLTestUtils
case class FunctionResult(f1: String, f2: String)
-class UDFSuite extends QueryTest {
+class UDFSuite extends QueryTest with SQLTestUtils {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
+ override def sqlContext(): SQLContext = ctx
+
test("built-in fixed arity expressions") {
val df = ctx.emptyDataFrame
df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
@@ -51,6 +54,25 @@ class UDFSuite extends QueryTest {
df.selectExpr("count(distinct a)")
}
+ test("SPARK-8003 spark_partition_id") {
+ val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
+ df.registerTempTable("tmp_table")
+ checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
+ ctx.dropTempTable("tmp_table")
+ }
+
+ test("SPARK-8005 input_file_name") {
+ withTempPath { dir =>
+ val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
+ data.write.parquet(dir.getCanonicalPath)
+ ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
+ val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
+ assert(answer.contains(dir.getCanonicalPath))
+ assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
+ ctx.dropTempTable("test_table")
+ }
+ }
+
test("error reporting for incorrect number of arguments") {
val df = ctx.emptyDataFrame
val e = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index ad3bb1744cb3c..e72a1bc6c4e20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
-import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.memory.MemoryAllocator
import org.apache.spark.unsafe.types.UTF8String
@@ -67,4 +67,19 @@ class UnsafeRowSuite extends SparkFunSuite {
assert(bytesFromArrayBackedRow === bytesFromOffheapRow)
}
+
+ test("calling getDouble() and getFloat() on null columns") {
+ val row = InternalRow.apply(null, null)
+ val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row)
+ assert(unsafeRow.getFloat(0) === row.getFloat(0))
+ assert(unsafeRow.getDouble(1) === row.getDouble(1))
+ }
+
+ test("calling get(ordinal, datatype) on null columns") {
+ val row = InternalRow.apply(null)
+ val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row)
+ for (dataType <- DataTypeTestUtils.atomicTypes) {
+ assert(unsafeRow.get(0, dataType) === null)
+ }
+ }
}
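
A note on why the getFloat/getDouble comparison above can line up on null slots: reading a null through a primitive accessor yields the primitive default (boxed null unboxes to 0), and the UnsafeRow writer leaves the value word of a null field zeroed, so both sides produce 0.0. This is stated here as an assumption consistent with the assertion in the test; the unboxing part can be checked in isolation:

// Boxed null unboxes to the primitive default in Scala, which is what the generic
// row accessor effectively returns for a null slot.
val boxed: Any = null
assert(boxed.asInstanceOf[Float] == 0.0f)
assert(boxed.asInstanceOf[Double] == 0.0)
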
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 45c9f06941c10..77ed4a9c0d5ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
- override def serialize(obj: Any): Seq[Double] = {
+ override def serialize(obj: Any): ArrayData = {
obj match {
case features: MyDenseVector =>
- features.data.toSeq
+ new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
}
}
override def deserialize(datum: Any): MyDenseVector = {
datum match {
- case data: Seq[_] =>
- new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
+ case data: ArrayData =>
+ new MyDenseVector(data.toArray.map(_.asInstanceOf[Double]))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 4499a7207031d..66014ddca0596 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[DoubleColumnStats], DOUBLE,
InternalRow(Double.MaxValue, Double.MinValue, 0))
testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
- testColumnStats(classOf[FixedDecimalColumnStats],
- FIXED_DECIMAL(15, 10), InternalRow(null, null, 0))
+ testDecimalColumnStats(InternalRow(null, null, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
@@ -52,7 +51,7 @@ class ColumnStatsSuite extends SparkFunSuite {
}
test(s"$columnStatsName: non-empty") {
- import ColumnarTestUtils._
+ import org.apache.spark.sql.columnar.ColumnarTestUtils._
val columnStats = columnStatsClass.newInstance()
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
@@ -73,4 +72,39 @@ class ColumnStatsSuite extends SparkFunSuite {
}
}
}
+
+ def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) {
+
+ val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
+ val columnType = FIXED_DECIMAL(15, 10)
+
+ test(s"$columnStatsName: empty") {
+ val columnStats = new FixedDecimalColumnStats(15, 10)
+ columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
+ case (actual, expected) => assert(actual === expected)
+ }
+ }
+
+ test(s"$columnStatsName: non-empty") {
+ import org.apache.spark.sql.columnar.ColumnarTestUtils._
+
+ val columnStats = new FixedDecimalColumnStats(15, 10)
+ val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+ rows.foreach(columnStats.gatherStats(_, 0))
+
+ val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
+ val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
+ val stats = columnStats.collectedStatistics
+
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
+ assertResult(10, "Wrong null count")(stats.genericGet(2))
+ assertResult(20, "Wrong row count")(stats.genericGet(3))
+ assertResult(stats.genericGet(4), "Wrong size in bytes") {
+ rows.map { row =>
+ if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+ }.sum
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 7b75f755918c1..707cd9c6d939b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.IsNull
+import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull}
import org.apache.spark.sql.test.TestSQLContext
class RowFormatConvertersSuite extends SparkPlanTest {
@@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest {
private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
assert(!outputsSafe.outputsUnsafeRows)
- private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
+ private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null))
assert(outputsUnsafe.outputsUnsafeRows)
test("planner should insert unsafe->safe conversions when required") {
@@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest {
}
test("filter can process unsafe rows") {
- val plan = Filter(IsNull(null), outputsUnsafe)
+ val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
- assert(getConverters(preparedPlan).isEmpty)
+ assert(getConverters(preparedPlan).size === 1)
assert(preparedPlan.outputsUnsafeRows)
}
test("filter can process safe rows") {
- val plan = Filter(IsNull(null), outputsSafe)
+ val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(!preparedPlan.outputsUnsafeRows)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 6a8f394545816..f46855edfe0de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -33,11 +33,13 @@ import scala.util.control.NonFatal
*/
class SparkPlanTest extends SparkFunSuite {
+ protected def sqlContext: SQLContext = TestSQLContext
+
/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
- TestSQLContext.implicits.localSeqToDataFrameHolder(data)
+ sqlContext.implicits.localSeqToDataFrameHolder(data)
}
/**
@@ -98,7 +100,7 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -121,7 +123,8 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
+ SparkPlanTest.checkAnswer(
+ input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -147,13 +150,14 @@ object SparkPlanTest {
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
- sortAnswers: Boolean): Option[String] = {
+ sortAnswers: Boolean,
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
- executePlan(expectedOutputPlan)
+ executePlan(expectedOutputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -168,7 +172,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
- executePlan(outputPlan)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -207,12 +211,13 @@ object SparkPlanTest {
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
- sortAnswers: Boolean): Option[String] = {
+ sortAnswers: Boolean,
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
- executePlan(outputPlan)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -275,10 +280,10 @@ object SparkPlanTest {
}
}
- private def executePlan(outputPlan: SparkPlan): Seq[Row] = {
+ private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+ val resolvedPlan = sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
similarity index 70%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 7a4baa9e4a49d..450963547c798 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
-class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
@@ -36,39 +36,21 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
}
- ignore("sort followed by limit should not leak memory") {
- // TODO: this test is going to fail until we implement a proper iterator interface
- // with a close() method.
- TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
+ test("sort followed by limit") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
- (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
+ (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)),
(child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
sortAnswers = false
)
}
- test("sort followed by limit") {
- TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
- try {
- checkThatPlansAgree(
- (1 to 100).map(v => Tuple1(v)).toDF("a"),
- (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
- (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
- sortAnswers = false
- )
- } finally {
- TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
-
- }
- }
-
test("sorting does not crash for large inputs") {
val sortOrder = 'a.asc :: Nil
val stringLength = 1024 * 1024 * 2
checkThatPlansAgree(
Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
- UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
+ TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
Sort(sortOrder, global = true, _: SparkPlan),
sortAnswers = false
)
@@ -88,11 +70,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
StructType(StructField("a", dataType, nullable = true) :: Nil)
)
- assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
+ assert(TungstenSort.supportsSchema(inputDf.schema))
checkThatPlansAgree(
inputDf,
plan => ConvertToSafe(
- UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)),
+ TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)),
Sort(sortOrder, global = true, _: SparkPlan),
sortAnswers = false
)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 9dd2220f0967e..8b1a9b21a96b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql.execution.joins
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer
@@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite {
}
test("UnsafeHashedRelation") {
+ val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
+ val toUnsafe = UnsafeProjection.create(schema)
+ val unsafeData = data.map(toUnsafe(_).copy()).toArray
+
val buildKey = Seq(BoundReference(0, IntegerType, false))
- val schema = StructType(StructField("a", IntegerType, true) :: Nil)
- val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1)
+ val keyGenerator = UnsafeProjection.create(buildKey)
+ val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
- val toUnsafeKey = UnsafeProjection.create(schema)
- val unsafeData = data.map(toUnsafeKey(_).copy()).toArray
assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
- assert(hashed.get(toUnsafeKey(InternalRow(10))) === null)
+ assert(hashed.get(toUnsafe(InternalRow(10))) === null)
val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
data2 += unsafeData(2).copy()
assert(hashed.get(unsafeData(2)) === data2)
- val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
- .asInstanceOf[UnsafeHashedRelation]
+ val os = new ByteArrayOutputStream()
+ val out = new ObjectOutputStream(os)
+ hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
+ out.flush()
+ val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+ val hashed2 = new UnsafeHashedRelation()
+ hashed2.readExternal(in)
assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
- assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
+ assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
assert(hashed2.get(unsafeData(2)) === data2)
}
}
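
The updated test serializes the UnsafeHashedRelation by calling writeExternal/readExternal directly instead of going through SparkSqlSerializer. The round trip it performs is the standard java.io.Externalizable pattern, sketched generically below; the helper name is illustrative and not part of the code base:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable, ObjectInputStream, ObjectOutputStream}

// Write the populated instance out, then read it back into a fresh empty instance.
def externalizableRoundTrip[T <: Externalizable](value: T, makeEmpty: () => T): T = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  value.writeExternal(out)
  out.flush()
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
  val copy = makeEmpty()
  copy.readExternal(in)
  copy
}
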
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 3ac312d6f4c50..f19f22fca7d54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData {
Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
)
- // Widening to DoubleType
+ // Widening to DecimalType
checkAnswer(
sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
- Row(21474836472.2) ::
- Row(92233720368547758071.3) :: Nil
+ Row(BigDecimal("21474836472.2")) ::
+ Row(BigDecimal("92233720368547758071.3")) :: Nil
)
- // Widening to DoubleType
+ // Widening to Double
checkAnswer(
sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"),
Row(101.2) :: Row(21474836471.2) :: Nil
@@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
- Row(92233720368547758071.2)
+ Row(BigDecimal("92233720368547758071.2"))
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
- Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
+ Row(new java.math.BigDecimal("92233720368547758071.2"))
)
// String and Boolean conflict: resolve the type as string.
@@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 13"),
- Row(14.3) :: Row(92233720368547758071.2) :: Nil
+ Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index c037faf4cfd92..a95f70f2bba69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.parquet
+import java.io.File
+
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.util.Utils
/**
* A test suite that tests various Parquet queries.
@@ -123,6 +126,30 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
}
}
+ test("Enabling/disabling merging partfiles when merging parquet schema") {
+ def testSchemaMerging(expectedColumnNumber: Int): Unit = {
+ withTempDir { dir =>
+ val basePath = dir.getCanonicalPath
+ sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+ // Delete the summary files so that, if part-files are not merged, one column will not be included.
+ Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
+ Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
+ assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+ }
+ }
+
+ withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true",
+ SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") {
+ testSchemaMerging(2)
+ }
+
+ withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true",
+ SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") {
+ testSchemaMerging(3)
+ }
+ }
+
test("Enabling/disabling schema merging") {
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
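
Note: an illustrative sketch (not part of the patch) of how the two SQLConf switches exercised by the new test combine when reading a partitioned Parquet table. It assumes an existing `sqlContext` and table path, and that `SQLConf` (a Spark-internal object) is visible from the calling code.

import org.apache.spark.sql.SQLConf

// Merge schemas across partitions, and ignore summary files so every part-file footer is read.
sqlContext.setConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true")
sqlContext.setConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key, "false")
val mergedDf = sqlContext.read.parquet("/tmp/partitioned-table")
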
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 5e189c3563ca8..cfb03ff485b7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -67,12 +67,12 @@ case class AllDataTypesScan(
override def schema: StructType = userSpecifiedSchema
- override def needConversion: Boolean = false
+ override def needConversion: Boolean = true
override def buildScan(): RDD[Row] = {
sqlContext.sparkContext.parallelize(from to to).map { i =>
- InternalRow(
- UTF8String.fromString(s"str_$i"),
+ Row(
+ s"str_$i",
s"str_$i".getBytes(),
i % 2 == 0,
i.toByte,
@@ -81,19 +81,19 @@ case class AllDataTypesScan(
i.toLong,
i.toFloat,
i.toDouble,
- Decimal(new java.math.BigDecimal(i)),
- Decimal(new java.math.BigDecimal(i)),
- DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)),
- DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)),
- UTF8String.fromString(s"varchar_$i"),
+ new java.math.BigDecimal(i),
+ new java.math.BigDecimal(i),
+ new Date(1970, 1, 1),
+ new Timestamp(20000 + i),
+ s"varchar_$i",
Seq(i, i + 1),
- Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
- Map(i -> UTF8String.fromString(i.toString)),
- Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
- InternalRow(i, UTF8String.fromString(i.toString)),
- InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
- InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
- }.asInstanceOf[RDD[Row]]
+ Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Map(i -> i.toString),
+ Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Row(i, i.toString),
+ Row(Seq(s"str_$i", s"str_${i + 1}"),
+ Row(Seq(new Date(1970, 1, i + 1)))))
+ }
}
}
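
For context, a minimal sketch (not part of the patch, class name hypothetical) of why the test relation above can now build plain Rows: with needConversion left at its default of true, a data source returns external values (String, java.sql types, Scala collections) and Spark performs the Catalyst conversion itself, instead of the source emitting InternalRow/UTF8String/Decimal directly.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._

case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType = StructType(StructField("i", IntegerType) :: Nil)

  // needConversion defaults to true, so buildScan may return external Row objects;
  // Spark converts them to InternalRow internally.
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(from to to).map(i => Row(i))
}
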
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala
similarity index 97%
rename from sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala
index 1fe4fe9629c02..1a5ba20404c4e 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala
@@ -23,16 +23,16 @@ import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.test.TestHive
/**
- * Runs the test cases that are included in the hive distribution with sort merge join is true.
+ * Runs the test cases that are included in the hive distribution with hash joins.
*/
-class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
+class HashJoinCompatibilitySuite extends HiveCompatibilitySuite {
override def beforeAll() {
super.beforeAll()
- TestHive.setConf(SQLConf.SORTMERGE_JOIN, true)
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, false)
}
override def afterAll() {
- TestHive.setConf(SQLConf.SORTMERGE_JOIN, false)
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, true)
super.afterAll()
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index b12b3838e615c..ec959cb2194b0 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udaf_covar_pop",
"udaf_covar_samp",
"udaf_histogram_numeric",
- "udaf_number_format",
"udf2",
"udf5",
"udf6",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index f467500259c91..5926ef9aa388b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -52,9 +52,8 @@ import scala.collection.JavaConversions._
* java.sql.Timestamp
* Complex Types =>
* Map: scala.collection.immutable.Map
- * List: scala.collection.immutable.Seq
- * Struct:
- * [[org.apache.spark.sql.catalyst.InternalRow]]
+ * List: [[org.apache.spark.sql.types.ArrayData]]
+ * Struct: [[org.apache.spark.sql.catalyst.InternalRow]]
* Union: NOT SUPPORTED YET
* The Complex types plays as a container, which can hold arbitrary data types.
*
@@ -297,7 +296,10 @@ private[hive] trait HiveInspectors {
}.toMap
case li: StandardConstantListObjectInspector =>
// take the value from the list inspector object, rather than the input data
- li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq
+ val values = li.getWritableConstantValue
+ .map(unwrap(_, li.getListElementObjectInspector))
+ .toArray
+ new GenericArrayData(values)
// if the value is null, we don't care about the object inspector type
case _ if data == null => null
case poi: VoidObjectInspector => null // always be null for void object inspector
@@ -339,7 +341,10 @@ private[hive] trait HiveInspectors {
}
case li: ListObjectInspector =>
Option(li.getList(data))
- .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq)
+ .map { l =>
+ val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray
+ new GenericArrayData(values)
+ }
.orNull
case mi: MapObjectInspector =>
Option(mi.getMap(data)).map(
@@ -391,7 +396,13 @@ private[hive] trait HiveInspectors {
case loi: ListObjectInspector =>
val wrapper = wrapperFor(loi.getListElementObjectInspector)
- (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null
+ (o: Any) => {
+ if (o != null) {
+ seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper))
+ } else {
+ null
+ }
+ }
case moi: MapObjectInspector =>
// The Predef.Map is scala.collection.immutable.Map.
@@ -520,7 +531,7 @@ private[hive] trait HiveInspectors {
case x: ListObjectInspector =>
val list = new java.util.ArrayList[Object]
val tpe = dataType.asInstanceOf[ArrayType].elementType
- a.asInstanceOf[Seq[_]].foreach {
+ a.asInstanceOf[ArrayData].toArray().foreach {
v => list.add(wrap(v, x.getListElementObjectInspector, tpe))
}
list
@@ -634,7 +645,8 @@ private[hive] trait HiveInspectors {
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
} else {
val list = new java.util.ArrayList[Object]()
- value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt)))
+ value.asInstanceOf[ArrayData].toArray()
+ .foreach(v => list.add(wrap(v, listObjectInspector, dt)))
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
}
case Literal(value, MapType(keyType, valueType, _)) =>
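
Illustrative sketch (not part of the patch) of the container that unwrap now returns for Hive lists; it assumes GenericArrayData lives alongside ArrayData in org.apache.spark.sql.types at this point in the codebase.

import org.apache.spark.sql.types.{ArrayData, GenericArrayData}

val arr: ArrayData = new GenericArrayData(Array[Any](1, 2, 3))
val first = arr.get(0)      // ordinal access, as used by HiveWindowFunction below
val plain = arr.toArray()   // back to Array[Any], as used when wrapping into Hive writables
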
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 3180c05445c9f..a8c9b4fa71b99 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -274,9 +274,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val metastoreSchema = StructType.fromAttributes(metastoreRelation.output)
val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging
- // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to
+ // NOTE: Instead of passing Metastore schema directly to `ParquetRelation`, we have to
// serialize the Metastore schema to JSON and pass it as a data source option because of the
- // evil case insensitivity issue, which is reconciled within `ParquetRelation2`.
+ // evil case insensitivity issue, which is reconciled within `ParquetRelation`.
val parquetOptions = Map(
ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json,
ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString)
@@ -290,7 +290,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = {
cachedDataSourceTables.getIfPresent(tableIdentifier) match {
case null => None // Cache miss
- case logical@LogicalRelation(parquetRelation: ParquetRelation) =>
+ case logical @ LogicalRelation(parquetRelation: ParquetRelation) =>
// If we have the same paths, same schema, and same partition spec,
// we will use the cached Parquet Relation.
val useCached =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2f79b0aad045c..e6df64d2642bc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -874,15 +874,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}
def matchSerDe(clause: Seq[ASTNode])
- : (Seq[(String, String)], String, Seq[(String, String)]) = clause match {
+ : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match {
case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
val rowFormat = propsClause.map {
case Token(name, Token(value, Nil) :: Nil) => (name, value)
}
- (rowFormat, "", Nil)
+ (rowFormat, None, Nil)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
- (Nil, serdeClass, Nil)
+ (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
Token("TOK_TABLEPROPERTIES",
@@ -891,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) =>
(name, value)
}
- (Nil, serdeClass, serdeProps)
+ (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
- case Nil => (Nil, "", Nil)
+ case Nil => (Nil, None, Nil)
}
val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 205e622195f09..7e3342cc84c0e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -17,15 +17,18 @@
package org.apache.spark.sql.hive.execution
-import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader}
+import java.io._
import java.util.Properties
+import javax.annotation.Nullable
import scala.collection.JavaConversions._
+import scala.util.control.NonFatal
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.spark.{TaskContext, Logging}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.CatalystTypeConverters
@@ -56,21 +59,53 @@ case class ScriptTransformation(
override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
protected override def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitions { iter =>
+ def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val cmd = List("/bin/bash", "-c", script)
val builder = new ProcessBuilder(cmd)
- // We need to start threads connected to the process pipeline:
- // 1) The error msg generated by the script process would be hidden.
- // 2) If the error msg is too big to chock up the buffer, the input logic would be hung
+
val proc = builder.start()
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream
- val reader = new BufferedReader(new InputStreamReader(inputStream))
- val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output)
+ // In order to avoid deadlocks, we need to consume the error output of the child process.
+ // To avoid issues caused by large error output, we use a circular buffer to limit the amount
+ // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
+ // that motivates this.
+ val stderrBuffer = new CircularBuffer(2048)
+ new RedirectThread(
+ errorStream,
+ stderrBuffer,
+ "Thread-ScriptTransformation-STDERR-Consumer").start()
+
+ val outputProjection = new InterpretedProjection(input, child.output)
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))
+
+ // This new thread will consume the ScriptTransformation's input rows and write them to the
+ // external process. That process's output will be read by this current thread.
+ val writerThread = new ScriptTransformationWriterThread(
+ inputIterator,
+ outputProjection,
+ inputSerde,
+ inputSoi,
+ ioschema,
+ outputStream,
+ proc,
+ stderrBuffer,
+ TaskContext.get()
+ )
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (outputSerde, outputSoi) = {
+ ioschema.initOutputSerDe(output).getOrElse((null, null))
+ }
- val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
+ val reader = new BufferedReader(new InputStreamReader(inputStream))
+ val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
var cacheRow: InternalRow = null
var curLine: String = null
var eof: Boolean = false
@@ -79,12 +114,26 @@ case class ScriptTransformation(
if (outputSerde == null) {
if (curLine == null) {
curLine = reader.readLine()
- curLine != null
+ if (curLine == null) {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
+ } else {
+ true
+ }
} else {
true
}
} else {
- !eof
+ if (eof) {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
+ } else {
+ true
+ }
}
}
@@ -110,11 +159,11 @@ case class ScriptTransformation(
}
i += 1
})
- return mutableRow
+ mutableRow
} catch {
case e: EOFException =>
eof = true
- return null
+ null
}
}
@@ -127,13 +176,13 @@ case class ScriptTransformation(
val prevLine = curLine
curLine = reader.readLine()
if (!ioschema.schemaLess) {
- new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")))
- .asInstanceOf[Array[Any]])
+ new GenericInternalRow(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+ .map(CatalystTypeConverters.convertToCatalyst))
} else {
- new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2))
- .asInstanceOf[Array[Any]])
+ new GenericInternalRow(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+ .map(CatalystTypeConverters.convertToCatalyst))
}
} else {
val ret = deserialize()
@@ -146,49 +195,83 @@ case class ScriptTransformation(
}
}
- val (inputSerde, inputSoi) = ioschema.initInputSerDe(input)
- val dataOutputStream = new DataOutputStream(outputStream)
- val outputProjection = new InterpretedProjection(input, child.output)
+ writerThread.start()
- // TODO make the 2048 configurable?
- val stderrBuffer = new CircularBuffer(2048)
- // Consume the error stream from the pipeline, otherwise it will be blocked if
- // the pipeline is full.
- new RedirectThread(errorStream, // input stream from the pipeline
- stderrBuffer, // output to a circular buffer
- "Thread-ScriptTransformation-STDERR-Consumer").start()
+ outputIterator
+ }
- // Put the write(output to the pipeline) into a single thread
- // and keep the collector as remain in the main thread.
- // otherwise it will causes deadlock if the data size greater than
- // the pipeline / buffer capacity.
- new Thread(new Runnable() {
- override def run(): Unit = {
- Utils.tryWithSafeFinally {
- iter
- .map(outputProjection)
- .foreach { row =>
- if (inputSerde == null) {
- val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
- ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
-
- outputStream.write(data)
- } else {
- val writable = inputSerde.serialize(
- row.asInstanceOf[GenericInternalRow].values, inputSoi)
- prepareWritable(writable).write(dataOutputStream)
- }
- }
- outputStream.close()
- } {
- if (proc.waitFor() != 0) {
- logError(stderrBuffer.toString) // log the stderr circular buffer
- }
- }
- }
- }, "Thread-ScriptTransformation-Feed").start()
+ child.execute().mapPartitions { iter =>
+ if (iter.hasNext) {
+ processIterator(iter)
+ } else {
+ // If the input iterator has no rows then do not launch the external script.
+ Iterator.empty
+ }
+ }
+ }
+}
- iterator
+private class ScriptTransformationWriterThread(
+ iter: Iterator[InternalRow],
+ outputProjection: Projection,
+ @Nullable inputSerde: AbstractSerDe,
+ @Nullable inputSoi: ObjectInspector,
+ ioschema: HiveScriptIOSchema,
+ outputStream: OutputStream,
+ proc: Process,
+ stderrBuffer: CircularBuffer,
+ taskContext: TaskContext
+ ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
+
+ setDaemon(true)
+
+ @volatile private var _exception: Throwable = null
+
+ /** Contains the exception thrown while writing the parent iterator to the external process. */
+ def exception: Option[Throwable] = Option(_exception)
+
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ TaskContext.setTaskContext(taskContext)
+
+ val dataOutputStream = new DataOutputStream(outputStream)
+
+ // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
+ // let's use a variable to record whether the `finally` block was hit due to an exception
+ var threwException: Boolean = true
+ try {
+ iter.map(outputProjection).foreach { row =>
+ if (inputSerde == null) {
+ val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
+ ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
+ outputStream.write(data)
+ } else {
+ val writable = inputSerde.serialize(
+ row.asInstanceOf[GenericInternalRow].values, inputSoi)
+ prepareWritable(writable).write(dataOutputStream)
+ }
+ }
+ outputStream.close()
+ threwException = false
+ } catch {
+ case NonFatal(e) =>
+ // An error occurred while writing input, so kill the child process. According to the
+ // Javadoc this call will not throw an exception:
+ _exception = e
+ proc.destroy()
+ throw e
+ } finally {
+ try {
+ if (proc.waitFor() != 0) {
+ logError(stderrBuffer.toString) // log the stderr circular buffer
+ }
+ } catch {
+ case NonFatal(exceptionFromFinallyBlock) =>
+ if (!threwException) {
+ throw exceptionFromFinallyBlock
+ } else {
+ log.error("Exception in finally block", exceptionFromFinallyBlock)
+ }
+ }
}
}
}
@@ -200,33 +283,43 @@ private[hive]
case class HiveScriptIOSchema (
inputRowFormat: Seq[(String, String)],
outputRowFormat: Seq[(String, String)],
- inputSerdeClass: String,
- outputSerdeClass: String,
+ inputSerdeClass: Option[String],
+ outputSerdeClass: Option[String],
inputSerdeProps: Seq[(String, String)],
outputSerdeProps: Seq[(String, String)],
schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {
- val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"),
- ("TOK_TABLEROWFORMATLINES", "\n"))
+ private val defaultFormat = Map(
+ ("TOK_TABLEROWFORMATFIELD", "\t"),
+ ("TOK_TABLEROWFORMATLINES", "\n")
+ )
val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
- def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = {
- val (columns, columnTypes) = parseAttrs(input)
- val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps)
- (serde, initInputSoi(serde, columns, columnTypes))
+ def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
+ inputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(input)
+ val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
+ val fieldObjectInspectors = columnTypes.map(toInspector)
+ val objectInspector = ObjectInspectorFactory
+ .getStandardStructObjectInspector(columns, fieldObjectInspectors)
+ .asInstanceOf[ObjectInspector]
+ (serde, objectInspector)
+ }
}
- def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = {
- val (columns, columnTypes) = parseAttrs(output)
- val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps)
- (serde, initOutputputSoi(serde))
+ def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = {
+ outputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(output)
+ val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
+ val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
+ (serde, structObjectInspector)
+ }
}
- def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-
+ private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
val columns = attrs.map {
case aref: AttributeReference => aref.name
case e: NamedExpression => e.name
@@ -242,52 +335,25 @@ case class HiveScriptIOSchema (
(columns, columnTypes)
}
- def initSerDe(serdeClassName: String, columns: Seq[String],
- columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = {
+ private def initSerDe(
+ serdeClassName: String,
+ columns: Seq[String],
+ columnTypes: Seq[DataType],
+ serdeProps: Seq[(String, String)]): AbstractSerDe = {
- val serde: AbstractSerDe = if (serdeClassName != "") {
- val trimed_class = serdeClassName.split("'")(1)
- Utils.classForName(trimed_class)
- .newInstance.asInstanceOf[AbstractSerDe]
- } else {
- null
- }
+ val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
- if (serde != null) {
- val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
+ val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
- var propsMap = serdeProps.map(kv => {
- (kv._1.split("'")(1), kv._2.split("'")(1))
- }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
- propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
+ var propsMap = serdeProps.map(kv => {
+ (kv._1.split("'")(1), kv._2.split("'")(1))
+ }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
+ propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
- val properties = new Properties()
- properties.putAll(propsMap)
- serde.initialize(null, properties)
- }
+ val properties = new Properties()
+ properties.putAll(propsMap)
+ serde.initialize(null, properties)
serde
}
-
- def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType])
- : ObjectInspector = {
-
- if (inputSerde != null) {
- val fieldObjectInspectors = columnTypes.map(toInspector(_))
- ObjectInspectorFactory
- .getStandardStructObjectInspector(columns, fieldObjectInspectors)
- .asInstanceOf[ObjectInspector]
- } else {
- null
- }
- }
-
- def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = {
- if (outputSerde != null) {
- outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector]
- } else {
- null
- }
- }
}
-
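
Illustrative standalone sketch (not part of the patch) of the deadlock-avoidance pattern used above: stderr is drained into a bounded buffer by a RedirectThread, stdin is fed from a separate daemon thread, and the calling thread only reads stdout. CircularBuffer and RedirectThread are private[spark] utilities, so this sketch assumes it compiles inside an org.apache.spark package.

import java.io.PrintWriter

import org.apache.spark.util.{CircularBuffer, RedirectThread}

val proc = new ProcessBuilder("/bin/cat").start()

// Drain stderr into a bounded buffer so a chatty child process cannot block the pipe.
val stderrBuffer = new CircularBuffer(2048)
new RedirectThread(proc.getErrorStream, stderrBuffer, "stderr-consumer").start()

// Feed stdin from its own thread so the writer and reader cannot deadlock on full pipes.
val feeder = new Thread("stdin-feeder") {
  override def run(): Unit = {
    val writer = new PrintWriter(proc.getOutputStream)
    (1 to 1000).foreach(i => writer.println(s"line $i"))
    writer.close() // closing stdin lets the child process terminate
  }
}
feeder.setDaemon(true)
feeder.start()

// The current thread consumes the child's stdout.
scala.io.Source.fromInputStream(proc.getInputStream).getLines().foreach(println)
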
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 8732e9abf8d31..4a13022eddf60 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -431,7 +431,7 @@ private[hive] case class HiveWindowFunction(
// if pivotResult is true, we will get a Seq having the same size with the size
// of the window frame. At here, we will return the result at the position of
// index in the output buffer.
- outputBuffer.asInstanceOf[Seq[Any]].get(index)
+ outputBuffer.asInstanceOf[ArrayData].get(index)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 3662a4352f55d..7bbdef90cd6b9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -56,6 +56,7 @@ object TestHive
.set("spark.sql.test", "")
.set("spark.sql.hive.metastore.barrierPrefixes",
"org.apache.spark.sql.hive.execution.PairSerDe")
+ .set("spark.buffer.pageSize", "4m")
// SPARK-8910
.set("spark.ui.enabled", "false")))
diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16
deleted file mode 100644
index c6f275a0db131..0000000000000
--- a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16
+++ /dev/null
@@ -1 +0,0 @@
-0.0 NULL NULL NULL
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 0330013f5325e..f719f2e06ab63 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -217,7 +217,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
test("wrap / unwrap Array Type") {
val dt = ArrayType(dataTypes(0))
- val d = row(0) :: row(0) :: Nil
+ val d = new GenericArrayData(Array(row(0), row(0)))
checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
checkValue(d,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index f067ea0d4fc75..bc72b0172a467 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -172,7 +172,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j }
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
- val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j }
+ val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j }
assert(shj.size === 1,
"ShuffledHashJoin should be planned when BroadcastHashJoin is turned off")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 4056dee777574..9b3ede43ee2d1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{Row, QueryTest}
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+ import ctx.implicits._
test("UDF case insensitive") {
ctx.udf.register("random0", () => { Math.random() })
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
new file mode 100644
index 0000000000000..0875232aede3e
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.scalatest.exceptions.TestFailedException
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.types.StringType
+
+class ScriptTransformationSuite extends SparkPlanTest {
+
+ override def sqlContext: SQLContext = TestHive
+
+ private val noSerdeIOSchema = HiveScriptIOSchema(
+ inputRowFormat = Seq.empty,
+ outputRowFormat = Seq.empty,
+ inputSerdeClass = None,
+ outputSerdeClass = None,
+ inputSerdeProps = Seq.empty,
+ outputSerdeProps = Seq.empty,
+ schemaLess = false
+ )
+
+ private val serdeIOSchema = noSerdeIOSchema.copy(
+ inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
+ outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
+ )
+
+ test("cat without SerDe") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = child,
+ ioschema = noSerdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+
+ test("cat with LazySimpleSerDe") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = child,
+ ioschema = serdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+
+ test("script transformation should not swallow errors from upstream operators (no serde)") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ val e = intercept[TestFailedException] {
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = ExceptionInjectingOperator(child),
+ ioschema = noSerdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+ assert(e.getMessage().contains("intentional exception"))
+ }
+
+ test("script transformation should not swallow errors from upstream operators (with serde)") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ val e = intercept[TestFailedException] {
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = ExceptionInjectingOperator(child),
+ ioschema = serdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+ assert(e.getMessage().contains("intentional exception"))
+ }
+}
+
+private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode {
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().map { x =>
+ assert(TaskContext.get() != null) // Make sure that TaskContext is defined.
+ Thread.sleep(1000) // This sleep gives the external process time to start.
+ throw new IllegalArgumentException("intentional exception")
+ }
+ }
+ override def output: Seq[Attribute] = child.output
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 65d4e933bf8e9..2780d5b6adbcf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.spark.{SparkException, SparkConf, Logging}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.{MetadataCleaner, Utils}
import org.apache.spark.streaming.scheduler.JobGenerator
@@ -100,7 +101,7 @@ object Checkpoint extends Logging {
}
val path = new Path(checkpointDir)
- val fs = fsOption.getOrElse(path.getFileSystem(new Configuration()))
+ val fs = fsOption.getOrElse(path.getFileSystem(SparkHadoopUtil.get.conf))
if (fs.exists(path)) {
val statuses = fs.listStatus(path)
if (statuses != null) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 92438f1b1fbf7..177e710ace54b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.input.FixedLengthBinaryInputFormat
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.serializer.SerializationDebugger
@@ -110,7 +111,7 @@ class StreamingContext private[streaming] (
* Recreate a StreamingContext from a checkpoint file.
* @param path Path to the directory that was specified as the checkpoint directory
*/
- def this(path: String) = this(path, new Configuration)
+ def this(path: String) = this(path, SparkHadoopUtil.get.conf)
/**
* Recreate a StreamingContext from a checkpoint file using an existing SparkContext.
@@ -803,7 +804,7 @@ object StreamingContext extends Logging {
def getActiveOrCreate(
checkpointPath: String,
creatingFunc: () => StreamingContext,
- hadoopConf: Configuration = new Configuration(),
+ hadoopConf: Configuration = SparkHadoopUtil.get.conf,
createOnError: Boolean = false
): StreamingContext = {
ACTIVATION_LOCK.synchronized {
@@ -828,7 +829,7 @@ object StreamingContext extends Logging {
def getOrCreate(
checkpointPath: String,
creatingFunc: () => StreamingContext,
- hadoopConf: Configuration = new Configuration(),
+ hadoopConf: Configuration = SparkHadoopUtil.get.conf,
createOnError: Boolean = false
): StreamingContext = {
val checkpointOption = CheckpointReader.read(
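
For reference, a hypothetical usage sketch (paths and app name invented) of the call sites affected by the new default: the Hadoop configuration used to read the checkpoint now comes from SparkHadoopUtil.get.conf rather than a bare new Configuration(), so spark.hadoop.* settings are honoured.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

def createContext(): StreamingContext = {
  val conf = new SparkConf().setAppName("checkpointed-app")
  val ssc = new StreamingContext(conf, Seconds(10))
  ssc.checkpoint("hdfs:///tmp/checkpoints")
  // DStream graph setup omitted in this sketch.
  ssc
}

// If a checkpoint exists it is read with the Spark-managed Hadoop configuration.
val ssc = StreamingContext.getOrCreate("hdfs:///tmp/checkpoints", createContext _)
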
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 959ac9c177f81..26383e420101e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -788,7 +788,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
- conf: Configuration = new Configuration) {
+ conf: Configuration = dstream.context.sparkContext.hadoopConfiguration) {
dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 40deb6d7ea79a..35cc3ce5cf468 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -33,6 +33,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
import org.apache.spark.api.java.function.{Function0 => JFunction0}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
@@ -136,7 +137,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
* Recreate a JavaStreamingContext from a checkpoint file.
* @param path Path to the directory that was specified as the checkpoint directory
*/
- def this(path: String) = this(new StreamingContext(path, new Configuration))
+ def this(path: String) = this(new StreamingContext(path, SparkHadoopUtil.get.conf))
/**
* Re-creates a JavaStreamingContext from a checkpoint file.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index d58c99a8ff321..a6c4cd220e42f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -21,7 +21,9 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDDOperationScope
-import org.apache.spark.streaming.{Time, Duration, StreamingContext}
+import org.apache.spark.streaming.{Duration, StreamingContext, Time}
+import org.apache.spark.streaming.scheduler.RateController
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.Utils
/**
@@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
/** This is an unique identifier for the input stream. */
val id = ssc.getNewInputStreamId()
+ // Keep track of the freshest rate for this stream using the rateEstimator
+ protected[streaming] val rateController: Option[RateController] = None
+
/** A human-readable name of this InputDStream */
private[streaming] def name: String = {
// e.g. FlumePollingDStream -> "Flume polling stream"
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index a50f0efc030ce..646a8c3530a62 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -21,10 +21,11 @@ import scala.reflect.ClassTag
import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.BlockId
-import org.apache.spark.streaming._
+import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
import org.apache.spark.streaming.receiver.Receiver
-import org.apache.spark.streaming.scheduler.StreamInputInfo
+import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.streaming.util.WriteAheadLogUtils
/**
@@ -40,6 +41,17 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils
abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
extends InputDStream[T](ssc_) {
+ /**
+ * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
+ */
+ override protected[streaming] val rateController: Option[RateController] = {
+ if (RateController.isBackPressureEnabled(ssc.conf)) {
+ RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) }
+ } else {
+ None
+ }
+ }
+
/**
* Gets the receiver object that will be sent to the worker nodes
* to receive data. This method needs to defined by any specific implementation
@@ -110,4 +122,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
}
Some(blockRDD)
}
+
+ /**
+ * A RateController that sends the new rate to receivers, via the receiver tracker.
+ */
+ private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator)
+ extends RateController(id, estimator) {
+ override def publish(rate: Long): Unit =
+ ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
+ }
}
+
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 4af9b6d3b56ab..58bdda7794bf2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -66,6 +66,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
}
eventLoop.start()
+ // attach rate controllers of input streams to receive batch completion updates
+ for {
+ inputDStream <- ssc.graph.getInputStreams
+ rateController <- inputDStream.rateController
+ } ssc.addStreamingListener(rateController)
+
listenerBus.start(ssc.sparkContext)
receiverTracker = new ReceiverTracker(ssc)
inputInfoTracker = new InputInfoTracker(ssc)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
new file mode 100644
index 0000000000000..882ca0676b6ad
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import java.io.ObjectInputStream
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.concurrent.{ExecutionContext, Future}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * A StreamingListener that receives batch completion updates, and maintains
+ * an estimate of the speed at which this stream should ingest messages,
+ * given an estimate computation from a `RateEstimator`.
+ */
+private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator)
+ extends StreamingListener with Serializable {
+
+ init()
+
+ protected def publish(rate: Long): Unit
+
+ @transient
+ implicit private var executionContext: ExecutionContext = _
+
+ @transient
+ private var rateLimit: AtomicLong = _
+
+ /**
+ * An initialization method called both from the constructor and Serialization code.
+ */
+ private def init() {
+ executionContext = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update"))
+ rateLimit = new AtomicLong(-1L)
+ }
+
+ private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
+ ois.defaultReadObject()
+ init()
+ }
+
+ /**
+ * Compute the new rate limit and publish it asynchronously.
+ */
+ private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
+ Future[Unit] {
+ val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay)
+ newRate.foreach { s =>
+ rateLimit.set(s.toLong)
+ publish(getLatestRate())
+ }
+ }
+
+ def getLatestRate(): Long = rateLimit.get()
+
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ val elements = batchCompleted.batchInfo.streamIdToInputInfo
+
+ for {
+ processingEnd <- batchCompleted.batchInfo.processingEndTime;
+ workDelay <- batchCompleted.batchInfo.processingDelay;
+ waitDelay <- batchCompleted.batchInfo.schedulingDelay;
+ elems <- elements.get(streamUID).map(_.numRecords)
+ } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
+ }
+}
+
+object RateController {
+ def isBackPressureEnabled(conf: SparkConf): Boolean =
+ conf.getBoolean("spark.streaming.backpressure.enable", false)
+}
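
Illustrative sketch (not part of the patch): enabling the new back-pressure machinery and a trivial controller whose publish() only records the latest rate. RateController is private[streaming], so a real subclass (like ReceiverRateController above) must live inside that package; the class below is hypothetical.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.RateController
import org.apache.spark.streaming.scheduler.rate.RateEstimator

// Back pressure is keyed off the flag checked by RateController.isBackPressureEnabled.
val conf = new SparkConf().set("spark.streaming.backpressure.enable", "true")

// Hypothetical controller that only remembers the rate it was asked to publish.
class RecordingRateController(streamId: Int, estimator: RateEstimator)
  extends RateController(streamId, estimator) {
  @volatile var lastPublished: Long = -1L
  override protected def publish(rate: Long): Unit = { lastPublished = rate }
}
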
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 6270137951b5a..e076fb5ea174b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -223,7 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// Signal the receivers to delete old block data
if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) {
logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
- endpoint.send(CleanupOldBlocks(cleanupThreshTime))
+ synchronized {
+ if (isTrackerStarted) {
+ endpoint.send(CleanupOldBlocks(cleanupThreshTime))
+ }
+ }
}
}
@@ -285,8 +289,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
/** Update a receiver's maximum ingestion rate */
- def sendRateUpdate(streamUID: Int, newRate: Long): Unit = {
- endpoint.send(UpdateReceiverRateLimit(streamUID, newRate))
+ def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized {
+ if (isTrackerStarted) {
+ endpoint.send(UpdateReceiverRateLimit(streamUID, newRate))
+ }
}
/** Add new blocks for the given stream */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
new file mode 100644
index 0000000000000..a08685119e5d5
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler.rate
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkException
+
+/**
+ * A component that estimates the rate at which an InputDStream should ingest
+ * elements, based on updates at every batch completion.
+ */
+private[streaming] trait RateEstimator extends Serializable {
+
+ /**
+ * Computes the number of elements the stream attached to this `RateEstimator`
+ * should ingest per second, given an update on the size and completion
+ * times of the latest batch.
+ *
+ * @param time The timestamp of the current batch interval that just finished
+ * @param elements The number of elements that were processed in this batch
+ * @param processingDelay The time in ms that the job took to complete
+ * @param schedulingDelay The time in ms that the job spent in the scheduling queue
+ */
+ def compute(
+ time: Long,
+ elements: Long,
+ processingDelay: Long,
+ schedulingDelay: Long): Option[Double]
+}
+
+object RateEstimator {
+
+ /**
+ * Return a new RateEstimator based on the value of `spark.streaming.backpressure.rateEstimator`.
+ *
+ * @return None if there is no configured estimator, otherwise an instance of RateEstimator
+ * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any
+ * known estimators.
+ */
+ def create(conf: SparkConf): Option[RateEstimator] =
+ conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator =>
+ throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
+ }
+}
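
Illustrative sketch (not part of the patch) of a concrete estimator, in the spirit of the ConstantEstimator referenced by the streaming tests: it ignores the batch statistics and always proposes the same rate. The trait is private[streaming], so an actual implementation would sit in that package; the class name is hypothetical.

import org.apache.spark.streaming.scheduler.rate.RateEstimator

class FixedRateEstimator(fixedRate: Double) extends RateEstimator {
  override def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double] = Some(fixedRate)
}
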
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index a34f23475804a..e0718f73aa13f 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -1735,6 +1735,7 @@ public Integer call(String s) throws Exception {
@SuppressWarnings("unchecked")
@Test
public void testContextGetOrCreate() throws InterruptedException {
+ ssc.stop();
final SparkConf conf = new SparkConf()
.setMaster("local[2]")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 08faeaa58f419..255376807c957 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase {
test("repartition (more partitions)") {
val input = Seq(1 to 100, 101 to 200, 201 to 300)
val operation = (r: DStream[Int]) => r.repartition(5)
- val ssc = setupStreams(input, operation, 2)
- val output = runStreamsWithPartitions(ssc, 3, 3)
- assert(output.size === 3)
- val first = output(0)
- val second = output(1)
- val third = output(2)
-
- assert(first.size === 5)
- assert(second.size === 5)
- assert(third.size === 5)
-
- assert(first.flatten.toSet.equals((1 to 100).toSet) )
- assert(second.flatten.toSet.equals((101 to 200).toSet))
- assert(third.flatten.toSet.equals((201 to 300).toSet))
+ withStreamingContext(setupStreams(input, operation, 2)) { ssc =>
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 5)
+ assert(second.size === 5)
+ assert(third.size === 5)
+
+ assert(first.flatten.toSet.equals((1 to 100).toSet))
+ assert(second.flatten.toSet.equals((101 to 200).toSet))
+ assert(third.flatten.toSet.equals((201 to 300).toSet))
+ }
}
test("repartition (fewer partitions)") {
val input = Seq(1 to 100, 101 to 200, 201 to 300)
val operation = (r: DStream[Int]) => r.repartition(2)
- val ssc = setupStreams(input, operation, 5)
- val output = runStreamsWithPartitions(ssc, 3, 3)
- assert(output.size === 3)
- val first = output(0)
- val second = output(1)
- val third = output(2)
-
- assert(first.size === 2)
- assert(second.size === 2)
- assert(third.size === 2)
-
- assert(first.flatten.toSet.equals((1 to 100).toSet))
- assert(second.flatten.toSet.equals( (101 to 200).toSet))
- assert(third.flatten.toSet.equals((201 to 300).toSet))
+ withStreamingContext(setupStreams(input, operation, 5)) { ssc =>
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 2)
+ assert(second.size === 2)
+ assert(third.size === 2)
+
+ assert(first.flatten.toSet.equals((1 to 100).toSet))
+ assert(second.flatten.toSet.equals((101 to 200).toSet))
+ assert(third.flatten.toSet.equals((201 to 300).toSet))
+ }
}
test("groupByKey") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index d308ac05a54fe..67c2d900940ab 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver}
import org.apache.spark.util.{Clock, ManualClock, Utils}
/**
@@ -391,6 +393,32 @@ class CheckpointSuite extends TestSuiteBase {
testCheckpointedOperation(input, operation, output, 7)
}
+ test("recovery maintains rate controller") {
+ ssc = new StreamingContext(conf, batchDuration)
+ ssc.checkpoint(checkpointDir)
+
+ val dstream = new RateLimitInputDStream(ssc) {
+ override val rateController =
+ Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+ }
+ SingletonTestRateReceiver.reset()
+
+ val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2))
+ output.register()
+ runStreams(ssc, 5, 5)
+
+ SingletonTestRateReceiver.reset()
+ ssc = new StreamingContext(checkpointDir)
+ ssc.start()
+ val outputNew = advanceTimeWithRealDelay(ssc, 2)
+
+ eventually(timeout(5.seconds)) {
+ assert(dstream.getCurrentRateLimit === Some(200))
+ }
+ ssc.stop()
+ ssc = null
+ }
+
// This tests whether file input stream remembers what files were seen before
// the master failure and uses them again to process a large window operation.
// It also tests whether batches, whose processing was incomplete due to the
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index b74d67c63a788..ec2852d9a0206 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -325,27 +325,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
}
test("test track the number of input stream") {
- val ssc = new StreamingContext(conf, batchDuration)
+ withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
- class TestInputDStream extends InputDStream[String](ssc) {
- def start() { }
- def stop() { }
- def compute(validTime: Time): Option[RDD[String]] = None
- }
+ class TestInputDStream extends InputDStream[String](ssc) {
+ def start() {}
- class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) {
- def getReceiver: Receiver[String] = null
- }
+ def stop() {}
+
+ def compute(validTime: Time): Option[RDD[String]] = None
+ }
+
+ class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) {
+ def getReceiver: Receiver[String] = null
+ }
- // Register input streams
- val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream)
- val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream)
+ // Register input streams
+ val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream)
+ val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream)
- assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length)
- assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length)
- assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams)
- assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i))
- assert(receiverInputStreams.map(_.id) === Array(0, 1))
+ assert(ssc.graph.getInputStreams().length ==
+ receiverInputStreams.length + inputStreams.length)
+ assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length)
+ assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams)
+ assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i))
+ assert(receiverInputStreams.map(_.id) === Array(0, 1))
+ }
}
def testFileStream(newFilesOnly: Boolean) {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
index 6e9d4431090a2..0e64b57e0ffd8 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
@@ -244,7 +244,13 @@ object MasterFailureTest extends Logging {
} catch {
case e: Exception => logError("Error running streaming context", e)
}
- if (killingThread.isAlive) killingThread.interrupt()
+ if (killingThread.isAlive) {
+ killingThread.interrupt()
+ // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is
+ // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env
+ // to null after the next test creates its new SparkContext, failing that test.
+ killingThread.join()
+ }
ssc.stop()
logInfo("Has been killed = " + killed)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 4bba9691f8aa5..84a5fbb3d95eb 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -120,7 +120,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory)
- val ssc = new StreamingContext(myConf, batchDuration)
+ ssc = new StreamingContext(myConf, batchDuration)
assert(ssc.checkpointDir != null)
}
@@ -369,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
}
assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop")
+ var t: Thread = null
// test whether wait exits if context is stopped
failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown
- new Thread() {
+ t = new Thread() {
override def run() {
Thread.sleep(500)
ssc.stop()
}
- }.start()
+ }
+ t.start()
ssc.awaitTermination()
}
+ // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped
+ // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after
+ // the next test creates its new SparkContext, failing that test.
+ t.join()
}
test("awaitTermination after stop") {
@@ -430,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(ssc.awaitTerminationOrTimeout(500) === false)
}
+ var t: Thread = null
// test whether awaitTerminationOrTimeout() return true if context is stopped
failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown
- new Thread() {
+ t = new Thread() {
override def run() {
Thread.sleep(500)
ssc.stop()
}
- }.start()
+ }
+ t.start()
assert(ssc.awaitTerminationOrTimeout(10000) === true)
}
+ // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped
+ // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after
+ // the next test creates its new SparkContext, failing that test.
+ t.join()
}
test("getOrCreate") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 4bc1dd4a30fc4..d840c349bbbc4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
val input = (1 to 4).map(Seq(_)).toSeq
val operation = (d: DStream[Int]) => d.map(x => x)
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
// To make sure that the processing start and end times in collected
// information are different for successive batches
override def batchDuration: Duration = Milliseconds(100)
override def actuallyWait: Boolean = true
test("batch info reporting") {
- val ssc = setupStreams(input, operation)
+ ssc = setupStreams(input, operation)
val collector = new BatchInfoCollector
ssc.addStreamingListener(collector)
runStreams(ssc, input.size, input.size)
@@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
}
test("receiver info reporting") {
- val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
+ ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
inputStream.foreachRDD(_.count)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
new file mode 100644
index 0000000000000..921da773f6c11
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+
+import org.scalatest.Matchers._
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+
+class RateControllerSuite extends TestSuiteBase {
+
+ override def useManualClock: Boolean = false
+
+ test("rate controller publishes updates") {
+ val ssc = new StreamingContext(conf, batchDuration)
+ withStreamingContext(ssc) { ssc =>
+ val dstream = new RateLimitInputDStream(ssc)
+ dstream.register()
+ ssc.start()
+
+ eventually(timeout(10.seconds)) {
+ assert(dstream.publishCalls > 0)
+ }
+ }
+ }
+
+ test("publish rates reach receivers") {
+ val ssc = new StreamingContext(conf, batchDuration)
+ withStreamingContext(ssc) { ssc =>
+ val dstream = new RateLimitInputDStream(ssc) {
+ override val rateController =
+ Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+ }
+ dstream.register()
+ SingletonTestRateReceiver.reset()
+ ssc.start()
+
+ eventually(timeout(10.seconds)) {
+ assert(dstream.getCurrentRateLimit === Some(200))
+ }
+ }
+ }
+
+ test("multiple publish rates reach receivers") {
+ val ssc = new StreamingContext(conf, batchDuration)
+ withStreamingContext(ssc) { ssc =>
+ val rates = Seq(100L, 200L, 300L)
+
+ val dstream = new RateLimitInputDStream(ssc) {
+ override val rateController =
+ Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*)))
+ }
+ SingletonTestRateReceiver.reset()
+ dstream.register()
+
+ val observedRates = mutable.HashSet.empty[Long]
+ ssc.start()
+
+ eventually(timeout(20.seconds)) {
+ dstream.getCurrentRateLimit.foreach(observedRates += _)
+ // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver
+ observedRates should contain theSameElementsAs (rates :+ Long.MaxValue)
+ }
+ }
+ }
+}
+
+private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator {
+ private var idx: Int = 0
+
+ private def nextRate(): Double = {
+ val rate = rates(idx)
+ idx = (idx + 1) % rates.size
+ rate
+ }
+
+ def compute(
+ time: Long,
+ elements: Long,
+ processingDelay: Long,
+ schedulingDelay: Long): Option[Double] = Some(nextRate())
+}
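For context, the `compute` signature that `ConstantEstimator` implements above is the entire contract a test-only `RateEstimator` has to satisfy. A minimal Scala sketch of another such estimator follows; the class name `HalvingEstimator` and the 100 ms delay threshold are illustrative assumptions, not part of this patch:

    import org.apache.spark.streaming.scheduler.rate.RateEstimator

    // Hypothetical test estimator: fall back to half the base rate once scheduling delay builds up.
    private[streaming] class HalvingEstimator(baseRate: Double) extends RateEstimator {
      def compute(
          time: Long,
          elements: Long,
          processingDelay: Long,
          schedulingDelay: Long): Option[Double] = {
        // 100 ms is an assumed threshold; a real estimator would derive it from the batch interval.
        if (schedulingDelay > 100) Some(baseRate / 2) else Some(baseRate)
      }
    }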
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
index 93f920fdc71f1..0418d776ecc9a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
@@ -64,7 +64,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
test("scheduleReceivers: " +
"schedule receivers evenly when there are more receivers than executors") {
- val receivers = (0 until 6).map(new DummyReceiver(_))
+ val receivers = (0 until 6).map(new RateTestReceiver(_))
val executors = (10000 until 10003).map(port => s"localhost:${port}")
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
val numReceiversOnExecutor = mutable.HashMap[String, Int]()
@@ -79,7 +79,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
test("scheduleReceivers: " +
"schedule receivers evenly when there are more executors than receivers") {
- val receivers = (0 until 3).map(new DummyReceiver(_))
+ val receivers = (0 until 3).map(new RateTestReceiver(_))
val executors = (10000 until 10006).map(port => s"localhost:${port}")
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
val numReceiversOnExecutor = mutable.HashMap[String, Int]()
@@ -94,8 +94,8 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
}
test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") {
- val receivers = (0 until 3).map(new DummyReceiver(_)) ++
- (3 until 6).map(new DummyReceiver(_, Some("localhost")))
+ val receivers = (0 until 3).map(new RateTestReceiver(_)) ++
+ (3 until 6).map(new RateTestReceiver(_, Some("localhost")))
val executors = (10000 until 10003).map(port => s"localhost:${port}") ++
(10003 until 10006).map(port => s"localhost2:${port}")
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
@@ -121,7 +121,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
}
test("scheduleReceivers: return empty scheduled executors if no executors") {
- val receivers = (0 until 3).map(new DummyReceiver(_))
+ val receivers = (0 until 3).map(new RateTestReceiver(_))
val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty)
scheduledExecutors.foreach { case (receiverId, executors) =>
assert(executors.isEmpty)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index e2159bd4f225d..afad5f16dbc71 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -29,69 +29,100 @@ import org.apache.spark.storage.StorageLevel
/** Testsuite for receiver scheduling */
class ReceiverTrackerSuite extends TestSuiteBase {
val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test")
- val ssc = new StreamingContext(sparkConf, Milliseconds(100))
test("Receiver tracker - propagates rate limit") {
- object ReceiverStartedWaiter extends StreamingListener {
- @volatile
- var started = false
-
- override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
- started = true
+ withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc =>
+ object ReceiverStartedWaiter extends StreamingListener {
+ @volatile
+ var started = false
+
+ override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
+ started = true
+ }
}
- }
-
- ssc.addStreamingListener(ReceiverStartedWaiter)
- ssc.scheduler.listenerBus.start(ssc.sc)
-
- val newRateLimit = 100L
- val inputDStream = new RateLimitInputDStream(ssc)
- val tracker = new ReceiverTracker(ssc)
- tracker.start()
- // we wait until the Receiver has registered with the tracker,
- // otherwise our rate update is lost
- eventually(timeout(5 seconds)) {
- assert(ReceiverStartedWaiter.started)
- }
- tracker.sendRateUpdate(inputDStream.id, newRateLimit)
- // this is an async message, we need to wait a bit for it to be processed
- eventually(timeout(3 seconds)) {
- assert(inputDStream.getCurrentRateLimit.get === newRateLimit)
+ ssc.addStreamingListener(ReceiverStartedWaiter)
+ ssc.scheduler.listenerBus.start(ssc.sc)
+ SingletonTestRateReceiver.reset()
+
+ val newRateLimit = 100L
+ val inputDStream = new RateLimitInputDStream(ssc)
+ val tracker = new ReceiverTracker(ssc)
+ tracker.start()
+ try {
+ // we wait until the Receiver has registered with the tracker,
+ // otherwise our rate update is lost
+ eventually(timeout(5 seconds)) {
+ assert(ReceiverStartedWaiter.started)
+ }
+ tracker.sendRateUpdate(inputDStream.id, newRateLimit)
+ // this is an async message, we need to wait a bit for it to be processed
+ eventually(timeout(3 seconds)) {
+ assert(inputDStream.getCurrentRateLimit.get === newRateLimit)
+ }
+ } finally {
+ tracker.stop(false)
+ }
}
}
}
-/** An input DStream with a hard-coded receiver that gives access to internals for testing. */
-private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
+/**
+ * An input DStream with a hard-coded receiver that gives access to internals for testing.
+ *
+ * @note Make sure to call {{{SingletonTestRateReceiver.reset()}}} before using this in a test,
+ * otherwise you may get a {{{NotSerializableException}}} when trying to serialize
+ * the receiver.
+ * @see [[SingletonTestRateReceiver]]
+ */
+private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext)
extends ReceiverInputDStream[Int](ssc_) {
- override def getReceiver(): DummyReceiver = SingletonDummyReceiver
+ override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver
def getCurrentRateLimit: Option[Long] = {
invokeExecutorMethod.getCurrentRateLimit
}
+ @volatile
+ var publishCalls = 0
+
+ override val rateController: Option[RateController] = {
+ Some(new RateController(id, new ConstantEstimator(100.0)) {
+ override def publish(rate: Long): Unit = {
+ publishCalls += 1
+ }
+ })
+ }
+
private def invokeExecutorMethod: ReceiverSupervisor = {
val c = classOf[Receiver[_]]
val ex = c.getDeclaredMethod("executor")
ex.setAccessible(true)
- ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor]
+ ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor]
}
}
/**
- * A Receiver as an object so we can read its rate limit.
+ * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when
+ * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being
+ * serialized when receivers are installed on executors.
*
* @note It's necessary to be a top-level object, or else serialization would create another
* one on the executor side and we won't be able to read its rate limit.
*/
-private object SingletonDummyReceiver extends DummyReceiver(0)
+private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) {
+
+ /** Reset the object to be usable in another test. */
+ def reset(): Unit = {
+ executor_ = null
+ }
+}
/**
* Dummy receiver implementation
*/
-private class DummyReceiver(receiverId: Int, host: Option[String] = None)
+private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
setReceiverId(receiverId)
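The `reset()` requirement documented on `RateLimitInputDStream` and `SingletonTestRateReceiver` above exists because the singleton keeps a non-null `executor_` field after a previous run, which makes it unserializable the next time receivers are shipped to executors. A sketch of the intended usage pattern, assuming it runs inside a `TestSuiteBase` test with a live `ssc` and the Eventually/SpanSugar imports used by the suites above:

    SingletonTestRateReceiver.reset()            // clear executor_ so the receiver serializes again
    val stream = new RateLimitInputDStream(ssc)  // hands out the singleton receiver
    stream.register()
    ssc.start()
    eventually(timeout(5.seconds)) {
      // getCurrentRateLimit reflects into the running ReceiverSupervisor, so it is only
      // defined once the receiver has actually started.
      assert(stream.getCurrentRateLimit.nonEmpty)
    }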
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
index 0891309f956d2..995f1197ccdfd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
@@ -22,15 +22,24 @@ import java.util.Properties
import org.scalatest.Matchers
import org.apache.spark.scheduler.SparkListenerJobStart
+import org.apache.spark.streaming._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.scheduler._
-import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase}
class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
val input = (1 to 4).map(Seq(_)).toSeq
val operation = (d: DStream[Int]) => d.map(x => x)
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
private def createJobStart(
batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = {
val properties = new Properties()
@@ -46,7 +55,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " +
"onReceiverStarted, onReceiverError, onReceiverStopped") {
- val ssc = setupStreams(input, operation)
+ ssc = setupStreams(input, operation)
val listener = new StreamingJobProgressListener(ssc)
val streamIdToInputInfo = Map(
@@ -141,7 +150,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
}
test("Remove the old completed batches when exceeding the limit") {
- val ssc = setupStreams(input, operation)
+ ssc = setupStreams(input, operation)
val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000)
val listener = new StreamingJobProgressListener(ssc)
@@ -158,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
}
test("out-of-order onJobStart and onBatchXXX") {
- val ssc = setupStreams(input, operation)
+ ssc = setupStreams(input, operation)
val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000)
val listener = new StreamingJobProgressListener(ssc)
@@ -209,7 +218,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
}
test("detect memory leak") {
- val ssc = setupStreams(input, operation)
+ ssc = setupStreams(input, operation)
val listener = new StreamingJobProgressListener(ssc)
val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000)
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index d0bde69cc1068..198e0684f32f8 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -74,12 +74,6 @@ public final class BytesToBytesMap {
*/
private long pageCursor = 0;
- /**
- * The size of the data pages that hold key and value data. Map entries cannot span multiple
- * pages, so this limits the maximum entry size.
- */
- private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
-
/**
* The maximum number of keys that BytesToBytesMap supports. The hash table has to be
* power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since
@@ -117,6 +111,12 @@ public final class BytesToBytesMap {
private final double loadFactor;
+ /**
+ * The size of the data pages that hold key and value data. Map entries cannot span multiple
+ * pages, so this limits the maximum entry size.
+ */
+ private final long pageSizeBytes;
+
/**
* Number of keys defined in the map.
*/
@@ -153,10 +153,12 @@ public BytesToBytesMap(
TaskMemoryManager memoryManager,
int initialCapacity,
double loadFactor,
+ long pageSizeBytes,
boolean enablePerfMetrics) {
this.memoryManager = memoryManager;
this.loadFactor = loadFactor;
this.loc = new Location();
+ this.pageSizeBytes = pageSizeBytes;
this.enablePerfMetrics = enablePerfMetrics;
if (initialCapacity <= 0) {
throw new IllegalArgumentException("Initial capacity must be greater than 0");
@@ -165,18 +167,26 @@ public BytesToBytesMap(
throw new IllegalArgumentException(
"Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY);
}
+ if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) {
+ throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " +
+ TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
+ }
allocate(initialCapacity);
}
- public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) {
- this(memoryManager, initialCapacity, 0.70, false);
+ public BytesToBytesMap(
+ TaskMemoryManager memoryManager,
+ int initialCapacity,
+ long pageSizeBytes) {
+ this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false);
}
public BytesToBytesMap(
TaskMemoryManager memoryManager,
int initialCapacity,
+ long pageSizeBytes,
boolean enablePerfMetrics) {
- this(memoryManager, initialCapacity, 0.70, enablePerfMetrics);
+ this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics);
}
/**
@@ -443,20 +453,20 @@ public void putNewKey(
// must be stored in the same memory page.
// (8 byte key length) (key) (8 byte value length) (value)
final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
- assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker.
+ assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker.
size++;
bitset.set(pos);
// If there's not enough space in the current page, allocate a new page (8 bytes are reserved
// for the end-of-page marker).
- if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) {
+ if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
if (currentDataPage != null) {
// There wasn't enough space in the current page, so write an end-of-page marker:
final Object pageBaseObject = currentDataPage.getBaseObject();
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
}
- MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES);
+ MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
@@ -538,10 +548,11 @@ public void free() {
/** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
public long getTotalMemoryConsumption() {
- return (
- dataPages.size() * PAGE_SIZE_BYTES +
- bitset.memoryBlock().size() +
- longArray.memoryBlock().size());
+ long totalDataPagesSize = 0L;
+ for (MemoryBlock dataPage : dataPages) {
+ totalDataPagesSize += dataPage.size();
+ }
+ return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
}
/**
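With the hard-coded 64 MB constant gone, callers of BytesToBytesMap now choose the page size explicitly, and oversized pages are rejected at construction time against TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES. A minimal Scala sketch of the new constructor; the ExecutorMemoryManager/MemoryAllocator.HEAP wiring is an assumption drawn from the existing AbstractBytesToBytesMapSuite setup:

    import org.apache.spark.unsafe.map.BytesToBytesMap
    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

    val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
    // 64 MB pages, matching the old hard-coded PAGE_SIZE_BYTES.
    val map = new BytesToBytesMap(taskMemoryManager, 64, 1L << 26)
    try {
      assert(map.size() == 0)
    } finally {
      map.free()
    }
    // Passing a page size above TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES (~16 GB) now throws
    // IllegalArgumentException from the constructor instead of failing later in allocatePage.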
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 10881969dbc78..dd70df3b1f791 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -58,8 +58,13 @@ public class TaskMemoryManager {
/** The number of entries in the page table. */
private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
- /** Maximum supported data page size */
- private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
+ /**
+ * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
+ * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page
+ * size is limited by the maximum amount of data that can be stored in a long[] array, which is
+ * (2^31 - 1) * 8 bytes (or about 16 gigabytes). Therefore, we cap this at 16 gigabytes.
+ */
+ public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
/** Bit mask for the lower 51 bits of a long. */
private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -110,9 +115,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
* intended for allocating large blocks of memory that will be shared between operators.
*/
public MemoryBlock allocatePage(long size) {
- if (size > MAXIMUM_PAGE_SIZE) {
+ if (size > MAXIMUM_PAGE_SIZE_BYTES) {
throw new IllegalArgumentException(
- "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
+ "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
}
final int pageNumber;
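The 16 GB figure follows directly from the on-heap allocator's backing long[]: at most Integer.MAX_VALUE elements of 8 bytes each. A quick arithmetic check in Scala (a sketch, not part of the patch):

    val maxLongArrayElements = (1L << 31) - 1          // Integer.MAX_VALUE elements in a long[]
    val maxPageSizeBytes = maxLongArrayElements * 8L   // 17179869176 bytes
    println(maxPageSizeBytes.toDouble / (1L << 30))    // ~16.0 GiB, the cap documented above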
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
similarity index 87%
rename from unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
rename to unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index 71b1a85a818ea..92a5e4f86f234 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -24,7 +24,7 @@
/**
* The internal representation of interval type.
*/
-public final class Interval implements Serializable {
+public final class CalendarInterval implements Serializable {
public static final long MICROS_PER_MILLI = 1000L;
public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000;
public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60;
@@ -58,7 +58,7 @@ private static long toLong(String s) {
}
}
- public static Interval fromString(String s) {
+ public static CalendarInterval fromString(String s) {
if (s == null) {
return null;
}
@@ -75,40 +75,40 @@ public static Interval fromString(String s) {
microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
microseconds += toLong(m.group(9));
- return new Interval((int) months, microseconds);
+ return new CalendarInterval((int) months, microseconds);
}
}
public final int months;
public final long microseconds;
- public Interval(int months, long microseconds) {
+ public CalendarInterval(int months, long microseconds) {
this.months = months;
this.microseconds = microseconds;
}
- public Interval add(Interval that) {
+ public CalendarInterval add(CalendarInterval that) {
int months = this.months + that.months;
long microseconds = this.microseconds + that.microseconds;
- return new Interval(months, microseconds);
+ return new CalendarInterval(months, microseconds);
}
- public Interval subtract(Interval that) {
+ public CalendarInterval subtract(CalendarInterval that) {
int months = this.months - that.months;
long microseconds = this.microseconds - that.microseconds;
- return new Interval(months, microseconds);
+ return new CalendarInterval(months, microseconds);
}
- public Interval negate() {
- return new Interval(-this.months, -this.microseconds);
+ public CalendarInterval negate() {
+ return new CalendarInterval(-this.months, -this.microseconds);
}
@Override
public boolean equals(Object other) {
if (this == other) return true;
- if (other == null || !(other instanceof Interval)) return false;
+ if (other == null || !(other instanceof CalendarInterval)) return false;
- Interval o = (Interval) other;
+ CalendarInterval o = (CalendarInterval) other;
return this.months == o.months && this.microseconds == o.microseconds;
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 85381cf0ef425..c38953f65d7d7 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -65,6 +65,19 @@ public static UTF8String fromBytes(byte[] bytes) {
}
}
+ /**
+ * Creates a UTF8String from the given byte array, which should be encoded in UTF-8.
+ *
+ * Note: `bytes` will be held by the returned UTF8String.
+ */
+ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) {
+ if (bytes != null) {
+ return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes);
+ } else {
+ return null;
+ }
+ }
+
/**
* Creates an UTF8String from String.
*/
@@ -89,10 +102,10 @@ public static UTF8String blankString(int length) {
return fromBytes(spaces);
}
- protected UTF8String(Object base, long offset, int size) {
+ protected UTF8String(Object base, long offset, int numBytes) {
this.base = base;
this.offset = offset;
- this.numBytes = size;
+ this.numBytes = numBytes;
}
/**
@@ -137,6 +150,32 @@ public int numChars() {
return len;
}
+ /**
+ * Returns a 64-bit integer that can be used as the prefix used in sorting.
+ */
+ public long getPrefix() {
+ // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string.
+ // If size is 0, just return 0.
+ // If size is between 1 and 4 (inclusive), assume data is 4-byte aligned under the hood and
+ // use a getInt to fetch the prefix.
+ // If size is greater than 4, assume we have at least 8 bytes of data to fetch.
+ // After getting the data, we use a mask to zero out bytes that are not part of the string.
+ long p;
+ if (numBytes >= 8) {
+ p = PlatformDependent.UNSAFE.getLong(base, offset);
+ } else if (numBytes > 4) {
+ p = PlatformDependent.UNSAFE.getLong(base, offset);
+ p = p & ((1L << numBytes * 8) - 1);
+ } else if (numBytes > 0) {
+ p = (long) PlatformDependent.UNSAFE.getInt(base, offset);
+ p = p & ((1L << numBytes * 8) - 1);
+ } else {
+ p = 0;
+ }
+ p = java.lang.Long.reverseBytes(p);
+ return p;
+ }
+
/**
* Returns the underlying bytes; this will be a copy if the string is part of another array.
*/
@@ -300,13 +339,13 @@ public UTF8String trimRight() {
}
public UTF8String reverse() {
- byte[] bytes = getBytes();
- byte[] result = new byte[bytes.length];
+ byte[] result = new byte[this.numBytes];
int i = 0; // position in byte
while (i < numBytes) {
int len = numBytesForFirstByte(getByte(i));
- System.arraycopy(bytes, i, result, result.length - i - len, len);
+ copyMemory(this.base, this.offset + i, result,
+ BYTE_ARRAY_OFFSET + result.length - i - len, len);
i += len;
}
@@ -316,11 +355,11 @@ public UTF8String reverse() {
public UTF8String repeat(int times) {
if (times <= 0) {
- return fromBytes(new byte[0]);
+ return EMPTY_UTF8;
}
byte[] newBytes = new byte[numBytes * times];
- System.arraycopy(getBytes(), 0, newBytes, 0, numBytes);
+ copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes);
int copied = 1;
while (copied < times) {
@@ -385,16 +424,15 @@ public UTF8String rpad(int len, UTF8String pad) {
UTF8String remain = pad.substring(0, spaces - padChars * count);
byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes];
- System.arraycopy(getBytes(), 0, data, 0, this.numBytes);
+ copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes);
int offset = this.numBytes;
int idx = 0;
- byte[] padBytes = pad.getBytes();
while (idx < count) {
- System.arraycopy(padBytes, 0, data, offset, pad.numBytes);
+ copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
offset += pad.numBytes;
}
- System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes);
+ copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
return UTF8String.fromBytes(data);
}
@@ -421,15 +459,14 @@ public UTF8String lpad(int len, UTF8String pad) {
int offset = 0;
int idx = 0;
- byte[] padBytes = pad.getBytes();
while (idx < count) {
- System.arraycopy(padBytes, 0, data, offset, pad.numBytes);
+ copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
offset += pad.numBytes;
}
- System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes);
+ copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
offset += remain.numBytes;
- System.arraycopy(getBytes(), 0, data, offset, numBytes());
+ copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes());
return UTF8String.fromBytes(data);
}
@@ -454,9 +491,9 @@ public static UTF8String concat(UTF8String... inputs) {
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
int len = inputs[i].numBytes;
- PlatformDependent.copyMemory(
+ copyMemory(
inputs[i].base, inputs[i].offset,
- result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+ result, BYTE_ARRAY_OFFSET + offset,
len);
offset += len;
}
@@ -494,7 +531,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
for (int i = 0, j = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
int len = inputs[i].numBytes;
- PlatformDependent.copyMemory(
+ copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
@@ -503,7 +540,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
j++;
// Add separator if this is not the last input.
if (j < numInputs) {
- PlatformDependent.copyMemory(
+ copyMemory(
separator.base, separator.offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
separator.numBytes);
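Two of the UTF8String additions above are easiest to see side by side: fromBytes(bytes, offset, numBytes) wraps a slice of the caller's array without copying (so the array must not be mutated afterwards), and getPrefix produces a 64-bit sort key whose ordering matches compareTo on the leading bytes. A short Scala sketch with illustrative values:

    import org.apache.spark.unsafe.types.UTF8String

    val bytes = "hello world".getBytes("UTF-8")
    val hello = UTF8String.fromBytes(bytes, 0, 5)        // shares `bytes`, no defensive copy
    assert(hello == UTF8String.fromString("hello"))

    // getPrefix packs up to the first 8 bytes into a long (byte-reversed so the first byte is
    // most significant), so prefix comparison agrees with compareTo when strings differ early.
    val a = UTF8String.fromString("apple")
    val b = UTF8String.fromString("banana")
    assert(a.getPrefix < b.getPrefix)
    assert(a.compareTo(b) < 0)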
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index dae47e4bab0cb..0be94ad371255 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -43,6 +43,7 @@ public abstract class AbstractBytesToBytesMapSuite {
private TaskMemoryManager memoryManager;
private TaskMemoryManager sizeLimitedMemoryManager;
+ private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
@Before
public void setup() {
@@ -110,7 +111,7 @@ private static boolean arrayEquals(
@Test
public void emptyMap() {
- BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+ BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
try {
Assert.assertEquals(0, map.size());
final int keyLengthInWords = 10;
@@ -125,7 +126,7 @@ public void emptyMap() {
@Test
public void setAndRetrieveAKey() {
- BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+ BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
final int recordLengthWords = 10;
final int recordLengthBytes = recordLengthWords * 8;
final byte[] keyData = getRandomByteArray(recordLengthWords);
@@ -177,7 +178,7 @@ public void setAndRetrieveAKey() {
@Test
public void iteratorTest() throws Exception {
final int size = 4096;
- BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2);
+ BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES);
try {
for (long i = 0; i < size; i++) {
final long[] value = new long[] { i };
@@ -235,7 +236,7 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception {
final int NUM_ENTRIES = 1000 * 1000;
final int KEY_LENGTH = 16;
final int VALUE_LENGTH = 40;
- final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES);
+ final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
// Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
// pages won't be evenly-divisible by records of this size, which will cause us to waste some
// space at the end of the page. This is necessary in order for us to take the end-of-record
@@ -304,7 +305,7 @@ public void randomizedStressTest() {
// Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
// into ByteBuffers in order to use them as keys here.
final Map expected = new HashMap();
- final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size);
+ final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES);
try {
// Fill the map to 90% full so that we can trigger probing
@@ -353,14 +354,15 @@ public void randomizedStressTest() {
@Test
public void initialCapacityBoundsChecking() {
try {
- new BytesToBytesMap(sizeLimitedMemoryManager, 0);
+ new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES);
Assert.fail("Expected IllegalArgumentException to be thrown");
} catch (IllegalArgumentException e) {
// expected exception
}
try {
- new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1);
+ new BytesToBytesMap(
+ sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES);
Assert.fail("Expected IllegalArgumentException to be thrown");
} catch (IllegalArgumentException e) {
// expected exception
@@ -368,15 +370,15 @@ public void initialCapacityBoundsChecking() {
// Can allocate _at_ the max capacity
BytesToBytesMap map =
- new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY);
+ new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES);
map.free();
}
@Test
public void resizingLargeMap() {
// As long as a map's capacity is below the max, we should be able to resize up to the max
- BytesToBytesMap map =
- new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64);
+ BytesToBytesMap map = new BytesToBytesMap(
+ sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES);
map.growAndRehash();
map.free();
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
index d29517cda66a3..e6733a7aae6f5 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
@@ -20,16 +20,16 @@
import org.junit.Test;
import static junit.framework.Assert.*;
-import static org.apache.spark.unsafe.types.Interval.*;
+import static org.apache.spark.unsafe.types.CalendarInterval.*;
public class IntervalSuite {
@Test
public void equalsTest() {
- Interval i1 = new Interval(3, 123);
- Interval i2 = new Interval(3, 321);
- Interval i3 = new Interval(1, 123);
- Interval i4 = new Interval(3, 123);
+ CalendarInterval i1 = new CalendarInterval(3, 123);
+ CalendarInterval i2 = new CalendarInterval(3, 321);
+ CalendarInterval i3 = new CalendarInterval(1, 123);
+ CalendarInterval i4 = new CalendarInterval(3, 123);
assertNotSame(i1, i2);
assertNotSame(i1, i3);
@@ -39,21 +39,21 @@ public void equalsTest() {
@Test
public void toStringTest() {
- Interval i;
+ CalendarInterval i;
- i = new Interval(34, 0);
+ i = new CalendarInterval(34, 0);
assertEquals(i.toString(), "interval 2 years 10 months");
- i = new Interval(-34, 0);
+ i = new CalendarInterval(-34, 0);
assertEquals(i.toString(), "interval -2 years -10 months");
- i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
+ i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds");
- i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123);
+ i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123);
assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds");
- i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
+ i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
}
@@ -72,33 +72,33 @@ public void fromStringTest() {
String input;
input = "interval -5 years 23 month";
- Interval result = new Interval(-5 * 12 + 23, 0);
- assertEquals(Interval.fromString(input), result);
+ CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0);
+ assertEquals(CalendarInterval.fromString(input), result);
input = "interval -5 years 23 month ";
- assertEquals(Interval.fromString(input), result);
+ assertEquals(CalendarInterval.fromString(input), result);
input = " interval -5 years 23 month ";
- assertEquals(Interval.fromString(input), result);
+ assertEquals(CalendarInterval.fromString(input), result);
// Error cases
input = "interval 3month 1 hour";
- assertEquals(Interval.fromString(input), null);
+ assertEquals(CalendarInterval.fromString(input), null);
input = "interval 3 moth 1 hour";
- assertEquals(Interval.fromString(input), null);
+ assertEquals(CalendarInterval.fromString(input), null);
input = "interval";
- assertEquals(Interval.fromString(input), null);
+ assertEquals(CalendarInterval.fromString(input), null);
input = "int";
- assertEquals(Interval.fromString(input), null);
+ assertEquals(CalendarInterval.fromString(input), null);
input = "";
- assertEquals(Interval.fromString(input), null);
+ assertEquals(CalendarInterval.fromString(input), null);
input = null;
- assertEquals(Interval.fromString(input), null);
+ assertEquals(CalendarInterval.fromString(input), null);
}
@Test
@@ -106,18 +106,18 @@ public void addTest() {
String input = "interval 3 month 1 hour";
String input2 = "interval 2 month 100 hour";
- Interval interval = Interval.fromString(input);
- Interval interval2 = Interval.fromString(input2);
+ CalendarInterval interval = CalendarInterval.fromString(input);
+ CalendarInterval interval2 = CalendarInterval.fromString(input2);
- assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR));
+ assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR));
input = "interval -10 month -81 hour";
input2 = "interval 75 month 200 hour";
- interval = Interval.fromString(input);
- interval2 = Interval.fromString(input2);
+ interval = CalendarInterval.fromString(input);
+ interval2 = CalendarInterval.fromString(input2);
- assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR));
+ assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR));
}
@Test
@@ -125,25 +125,25 @@ public void subtractTest() {
String input = "interval 3 month 1 hour";
String input2 = "interval 2 month 100 hour";
- Interval interval = Interval.fromString(input);
- Interval interval2 = Interval.fromString(input2);
+ CalendarInterval interval = CalendarInterval.fromString(input);
+ CalendarInterval interval2 = CalendarInterval.fromString(input2);
- assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR));
+ assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR));
input = "interval -10 month -81 hour";
input2 = "interval 75 month 200 hour";
- interval = Interval.fromString(input);
- interval2 = Interval.fromString(input2);
+ interval = CalendarInterval.fromString(input);
+ interval2 = CalendarInterval.fromString(input2);
- assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR));
+ assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR));
}
private void testSingleUnit(String unit, int number, int months, long microseconds) {
String input1 = "interval " + number + " " + unit;
String input2 = "interval " + number + " " + unit + "s";
- Interval result = new Interval(months, microseconds);
- assertEquals(Interval.fromString(input1), result);
- assertEquals(Interval.fromString(input2), result);
+ CalendarInterval result = new CalendarInterval(months, microseconds);
+ assertEquals(CalendarInterval.fromString(input1), result);
+ assertEquals(CalendarInterval.fromString(input2), result);
}
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index e2a5628ff4d93..f2cc19ca6b172 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -63,8 +63,27 @@ public void emptyStringTest() {
assertEquals(0, EMPTY_UTF8.numBytes());
}
+ @Test
+ public void prefix() {
+ assertTrue(fromString("a").getPrefix() - fromString("b").getPrefix() < 0);
+ assertTrue(fromString("ab").getPrefix() - fromString("b").getPrefix() < 0);
+ assertTrue(
+ fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0);
+ assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0);
+ assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0);
+
+ byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ byte[] buf2 = {1, 2, 3};
+ UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3);
+ UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8);
+ UTF8String str3 = UTF8String.fromBytes(buf2);
+ assertTrue(str1.getPrefix() - str2.getPrefix() < 0);
+ assertEquals(str1.getPrefix(), str3.getPrefix());
+ }
+
@Test
public void compareTo() {
+ assertTrue(fromString("").compareTo(fromString("a")) < 0);
assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0);
assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0);
assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0);
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 44acc7374d024..1d67b3ebb51b7 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -229,7 +229,11 @@ private[spark] class ApplicationMaster(
sparkContextRef.compareAndSet(sc, null)
}
- private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = {
+ private def registerAM(
+ _rpcEnv: RpcEnv,
+ driverRef: RpcEndpointRef,
+ uiAddress: String,
+ securityMgr: SecurityManager) = {
val sc = sparkContextRef.get()
val appId = client.getAttemptId().getApplicationId().toString()
@@ -246,6 +250,7 @@ private[spark] class ApplicationMaster(
RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt),
CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
allocator = client.register(driverUrl,
+ driverRef,
yarnConf,
_sparkConf,
if (sc != null) sc.preferredNodeLocationData else Map(),
@@ -262,17 +267,20 @@ private[spark] class ApplicationMaster(
*
* In cluster mode, the AM and the driver belong to same process
* so the AMEndpoint need not monitor lifecycle of the driver.
+ *
+ * @return A reference to the driver's RPC endpoint.
*/
private def runAMEndpoint(
host: String,
port: String,
- isClusterMode: Boolean): Unit = {
+ isClusterMode: Boolean): RpcEndpointRef = {
val driverEndpoint = rpcEnv.setupEndpointRef(
SparkEnv.driverActorSystemName,
RpcAddress(host, port.toInt),
YarnSchedulerBackend.ENDPOINT_NAME)
amEndpoint =
rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode))
+ driverEndpoint
}
private def runDriver(securityMgr: SecurityManager): Unit = {
@@ -290,11 +298,11 @@ private[spark] class ApplicationMaster(
"Timed out waiting for SparkContext.")
} else {
rpcEnv = sc.env.rpcEnv
- runAMEndpoint(
+ val driverRef = runAMEndpoint(
sc.getConf.get("spark.driver.host"),
sc.getConf.get("spark.driver.port"),
isClusterMode = true)
- registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+ registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
userClassThread.join()
}
}
@@ -302,9 +310,9 @@ private[spark] class ApplicationMaster(
private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
val port = sparkConf.getInt("spark.yarn.am.port", 0)
rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr)
- waitForSparkDriver()
+ val driverRef = waitForSparkDriver()
addAmIpFilter()
- registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+ registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
// In client mode the actor will stop the reporter thread.
reporterThread.join()
@@ -428,7 +436,7 @@ private[spark] class ApplicationMaster(
}
}
- private def waitForSparkDriver(): Unit = {
+ private def waitForSparkDriver(): RpcEndpointRef = {
logInfo("Waiting for Spark driver to be reachable.")
var driverUp = false
val hostport = args.userArgs(0)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index bc28ce5eeae72..4ac3397f1ad28 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -767,7 +767,7 @@ private[spark] class Client(
amContainer.setCommands(printableCommands)
logDebug("===============================================================================")
- logDebug("Yarn AM launch context:")
+ logDebug("YARN AM launch context:")
logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}")
logDebug(" env:")
launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") }
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 78e27fb7f3337..52580deb372c2 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -86,10 +86,17 @@ class ExecutorRunnable(
val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores,
appId, localResources)
- logInfo(s"Setting up executor with environment: $env")
- logInfo("Setting up executor with commands: " + commands)
- ctx.setCommands(commands)
+ logInfo(s"""
+ |===============================================================================
+ |YARN executor launch context:
+ | env:
+ |${env.map { case (k, v) => s" $k -> $v\n" }.mkString}
+ | command:
+ | ${commands.mkString(" ")}
+ |===============================================================================
+ """.stripMargin)
+ ctx.setCommands(commands)
ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr))
// If external shuffle service is enabled, register with the Yarn shuffle service already
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 6c103394af098..59caa787b6e20 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -36,6 +36,9 @@ import org.apache.log4j.{Level, Logger}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
/**
* YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding
@@ -52,6 +55,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
*/
private[yarn] class YarnAllocator(
driverUrl: String,
+ driverRef: RpcEndpointRef,
conf: Configuration,
sparkConf: SparkConf,
amClient: AMRMClient[ContainerRequest],
@@ -88,6 +92,9 @@ private[yarn] class YarnAllocator(
// Visible for testing.
private[yarn] val executorIdToContainer = new HashMap[String, Container]
+ private var numUnexpectedContainerRelease = 0L
+ private val containerIdToExecutorId = new HashMap[ContainerId, String]
+
// Executor memory in MB.
protected val executorMemory = args.executorMemory
// Additional memory overhead.
@@ -184,6 +191,7 @@ private[yarn] class YarnAllocator(
def killExecutor(executorId: String): Unit = synchronized {
if (executorIdToContainer.contains(executorId)) {
val container = executorIdToContainer.remove(executorId).get
+ containerIdToExecutorId.remove(container.getId)
internalReleaseContainer(container)
numExecutorsRunning -= 1
} else {
@@ -383,6 +391,7 @@ private[yarn] class YarnAllocator(
logInfo("Launching container %s for on host %s".format(containerId, executorHostname))
executorIdToContainer(executorId) = container
+ containerIdToExecutorId(container.getId) = executorId
val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname,
new HashSet[ContainerId])
@@ -413,12 +422,8 @@ private[yarn] class YarnAllocator(
private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = {
for (completedContainer <- completedContainers) {
val containerId = completedContainer.getContainerId
-
- if (releasedContainers.contains(containerId)) {
- // Already marked the container for release, so remove it from
- // `releasedContainers`.
- releasedContainers.remove(containerId)
- } else {
+ val alreadyReleased = releasedContainers.remove(containerId)
+ if (!alreadyReleased) {
// Decrement the number of executors running. The next iteration of
// the ApplicationMaster's reporting thread will take care of allocating.
numExecutorsRunning -= 1
@@ -460,6 +465,18 @@ private[yarn] class YarnAllocator(
allocatedContainerToHostMap.remove(containerId)
}
+
+ containerIdToExecutorId.remove(containerId).foreach { eid =>
+ executorIdToContainer.remove(eid)
+
+ if (!alreadyReleased) {
+ // The executor could have gone away (like no route to host, node failure, etc)
+ // Notify backend about the failure of the executor
+ numUnexpectedContainerRelease += 1
+ driverRef.send(RemoveExecutor(eid,
+ s"Yarn deallocated the executor $eid (container $containerId)"))
+ }
+ }
}
}
@@ -467,6 +484,9 @@ private[yarn] class YarnAllocator(
releasedContainers.add(container.getId())
amClient.releaseAssignedContainer(container.getId())
}
+
+ private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease
+
}
private object YarnAllocator {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 7f533ee55e8bb..4999f9c06210a 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils
import org.apache.hadoop.yarn.webapp.util.WebAppUtils
import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.SplitInfo
import org.apache.spark.util.Utils
@@ -56,6 +57,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg
*/
def register(
driverUrl: String,
+ driverRef: RpcEndpointRef,
conf: YarnConfiguration,
sparkConf: SparkConf,
preferredNodeLocations: Map[String, Set[SplitInfo]],
@@ -73,7 +75,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg
amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
registered = true
}
- new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr)
+ new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args,
+ securityMgr)
}
/**
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 37a789fcd375b..58318bf9bcc08 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -27,10 +27,14 @@ import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.scalatest.{BeforeAndAfterEach, Matchers}
+import org.mockito.Mockito._
+
import org.apache.spark.{SecurityManager, SparkFunSuite}
import org.apache.spark.SparkConf
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.deploy.yarn.YarnAllocator._
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.SplitInfo
class MockResolver extends DNSToSwitchMapping {
@@ -90,6 +94,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
"--class", "SomeClass")
new YarnAllocator(
"not used",
+ mock(classOf[RpcEndpointRef]),
conf,
sparkConf,
rmClient,
@@ -230,6 +235,30 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
handler.getNumPendingAllocate should be (1)
}
+ test("lost executor removed from backend") {
+ val handler = createAllocator(4)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getNumPendingAllocate should be (4)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2))
+
+ handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map())
+
+ val statuses = Seq(container1, container2).map { c =>
+ ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1)
+ }
+ handler.updateResourceRequests()
+ handler.processCompletedContainers(statuses.toSeq)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getNumPendingAllocate should be (2)
+ handler.getNumExecutorsFailed should be (2)
+ handler.getNumUnexpectedContainerRelease should be (2)
+ }
+
test("memory exceeded diagnostic regexes") {
val diagnostics =
"Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +