From ed413bcc78d8d97a1a0cd0871d7a20f7170476d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 29 Jun 2015 11:41:26 -0700 Subject: [PATCH 001/274] [SPARK-8692] [SQL] re-order the case statements that handling catalyst data types use same order: boolean, byte, short, int, date, long, timestamp, float, double, string, binary, decimal. Then we can easily check whether some data types are missing by just one glance, and make sure we handle data/timestamp just as int/long. Author: Wenchen Fan Closes #7073 from cloud-fan/fix-date and squashes the following commits: 463044d [Wenchen Fan] fix style 51cd347 [Wenchen Fan] refactor handling of date and timestmap --- .../expressions/SpecificMutableRow.scala | 12 +-- .../expressions/UnsafeRowConverter.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 6 +- .../spark/sql/columnar/ColumnAccessor.scala | 42 +++++----- .../spark/sql/columnar/ColumnBuilder.scala | 30 +++---- .../spark/sql/columnar/ColumnStats.scala | 74 ++++++++--------- .../spark/sql/columnar/ColumnType.scala | 10 +-- .../sql/execution/SparkSqlSerializer2.scala | 82 ++++++------------- .../sql/parquet/ParquetTableSupport.scala | 34 ++++---- .../spark/sql/parquet/ParquetTypes.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 9 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 54 ++++++------ .../sql/columnar/ColumnarTestUtils.scala | 8 +- .../NullableColumnAccessorSuite.scala | 6 +- .../columnar/NullableColumnBuilderSuite.scala | 6 +- 15 files changed, 174 insertions(+), 209 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 53fedb531cfb2..3928c0f2ffdaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -196,15 +196,15 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this(dataTypes: Seq[DataType]) = this( dataTypes.map { - case IntegerType => new MutableInt + case BooleanType => new MutableBoolean case ByteType => new MutableByte - case FloatType => new MutableFloat case ShortType => new MutableShort + // We use INT for DATE internally + case IntegerType | DateType => new MutableInt + // We use Long for Timestamp internally + case LongType | TimestampType => new MutableLong + case FloatType => new MutableFloat case DoubleType => new MutableDouble - case BooleanType => new MutableBoolean - case LongType => new MutableLong - case DateType => new MutableInt // We use INT for DATE internally - case TimestampType => new MutableLong // We use Long for Timestamp internally case _ => new MutableAny }.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 89adaf053b1a4..b61d490429e4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -128,14 +128,12 @@ private object UnsafeColumnWriter { case BooleanType => BooleanUnsafeColumnWriter case ByteType => ByteUnsafeColumnWriter case ShortType => ShortUnsafeColumnWriter - case IntegerType => IntUnsafeColumnWriter - case LongType => LongUnsafeColumnWriter + case IntegerType | DateType => IntUnsafeColumnWriter + case LongType | TimestampType => LongUnsafeColumnWriter case FloatType => FloatUnsafeColumnWriter case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case DateType => IntUnsafeColumnWriter - case TimestampType => LongUnsafeColumnWriter case t => throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e20e3a9dca502..57e0bede5db20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -120,15 +120,13 @@ class CodeGenContext { case BooleanType => JAVA_BOOLEAN case ByteType => JAVA_BYTE case ShortType => JAVA_SHORT - case IntegerType => JAVA_INT - case LongType => JAVA_LONG + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType - case DateType => JAVA_INT - case TimestampType => JAVA_LONG case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 64449b2659b4b..931469bed634a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -71,44 +71,44 @@ private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) +private[sql] class ByteColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BYTE) private[sql] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) +private[sql] class IntColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, INT) + private[sql] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) - -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) - private[sql] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) - extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) +private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DOUBLE) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class DateColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DATE) - -private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, TIMESTAMP) - private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) with NullableColumnAccessor +private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) + extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) + private[sql] class GenericColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) with NullableColumnAccessor +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + +private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, TIMESTAMP) + private[sql] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val dup = buffer.duplicate().order(ByteOrder.nativeOrder) @@ -118,17 +118,17 @@ private[sql] object ColumnAccessor { dup.getInt() dataType match { + case BooleanType => new BooleanColumnAccessor(dup) + case ByteType => new ByteColumnAccessor(dup) + case ShortType => new ShortColumnAccessor(dup) case IntegerType => new IntColumnAccessor(dup) + case DateType => new DateColumnAccessor(dup) case LongType => new LongColumnAccessor(dup) + case TimestampType => new TimestampColumnAccessor(dup) case FloatType => new FloatColumnAccessor(dup) case DoubleType => new DoubleColumnAccessor(dup) - case BooleanType => new BooleanColumnAccessor(dup) - case ByteType => new ByteColumnAccessor(dup) - case ShortType => new ShortColumnAccessor(dup) case StringType => new StringColumnAccessor(dup) case BinaryType => new BinaryColumnAccessor(dup) - case DateType => new DateColumnAccessor(dup) - case TimestampType => new TimestampColumnAccessor(dup) case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnAccessor(dup, precision, scale) case _ => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 1949625699ca8..087c52239713d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -94,17 +94,21 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) + private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) + +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) private[sql] class FixedDecimalColumnBuilder( precision: Int, @@ -113,19 +117,15 @@ private[sql] class FixedDecimalColumnBuilder( new FixedDecimalColumnStats, FIXED_DECIMAL(precision, scale)) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +// TODO (lian) Add support for array, struct and map +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) - -// TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) - private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 @@ -151,17 +151,17 @@ private[sql] object ColumnBuilder { columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { val builder: ColumnBuilder = dataType match { + case BooleanType => new BooleanColumnBuilder + case ByteType => new ByteColumnBuilder + case ShortType => new ShortColumnBuilder case IntegerType => new IntColumnBuilder + case DateType => new DateColumnBuilder case LongType => new LongColumnBuilder + case TimestampType => new TimestampColumnBuilder case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder - case BooleanType => new BooleanColumnBuilder - case ByteType => new ByteColumnBuilder - case ShortType => new ShortColumnBuilder case StringType => new StringColumnBuilder case BinaryType => new BinaryColumnBuilder - case DateType => new DateColumnBuilder - case TimestampType => new TimestampColumnBuilder case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnBuilder(precision, scale) case _ => new GenericColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 1bce214d1d6c3..00374d1fa3ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -132,17 +132,17 @@ private[sql] class ShortColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class LongColumnStats extends ColumnStats { - protected var upper = Long.MinValue - protected var lower = Long.MaxValue +private[sql] class IntColumnStats extends ColumnStats { + protected var upper = Int.MinValue + protected var lower = Int.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getLong(ordinal) + val value = row.getInt(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += LONG.defaultSize + sizeInBytes += INT.defaultSize } } @@ -150,17 +150,17 @@ private[sql] class LongColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DoubleColumnStats extends ColumnStats { - protected var upper = Double.MinValue - protected var lower = Double.MaxValue +private[sql] class LongColumnStats extends ColumnStats { + protected var upper = Long.MinValue + protected var lower = Long.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDouble(ordinal) + val value = row.getLong(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += DOUBLE.defaultSize + sizeInBytes += LONG.defaultSize } } @@ -186,35 +186,17 @@ private[sql] class FloatColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { - protected var upper: Decimal = null - protected var lower: Decimal = null - - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Decimal] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += FIXED_DECIMAL.defaultSize - } - } - - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) -} - -private[sql] class IntColumnStats extends ColumnStats { - protected var upper = Int.MinValue - protected var lower = Int.MaxValue +private[sql] class DoubleColumnStats extends ColumnStats { + protected var upper = Double.MinValue + protected var lower = Double.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getInt(ordinal) + val value = row.getDouble(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += INT.defaultSize + sizeInBytes += DOUBLE.defaultSize } } @@ -240,10 +222,6 @@ private[sql] class StringColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends IntColumnStats - -private[sql] class TimestampColumnStats extends LongColumnStats - private[sql] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) @@ -256,6 +234,24 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } +private[sql] class FixedDecimalColumnStats extends ColumnStats { + protected var upper: Decimal = null + protected var lower: Decimal = null + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Decimal] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += FIXED_DECIMAL.defaultSize + } + } + + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) +} + private[sql] class GenericColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) @@ -267,3 +263,7 @@ private[sql] class GenericColumnStats extends ColumnStats { override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, sizeInBytes) } + +private[sql] class DateColumnStats extends IntColumnStats + +private[sql] class TimestampColumnStats extends LongColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 8bf2151e4de68..fc72360c88fe1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -447,17 +447,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_, _] = { dataType match { + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT case IntegerType => INT + case DateType => DATE case LongType => LONG + case TimestampType => TIMESTAMP case FloatType => FLOAT case DoubleType => DOUBLE - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT case StringType => STRING case BinaryType => BINARY - case DateType => DATE - case TimestampType => TIMESTAMP case DecimalType.Fixed(precision, scale) if precision < 19 => FIXED_DECIMAL(precision, scale) case _ => GENERIC diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 74a22353b1d27..056d435eecd23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -237,7 +237,7 @@ private[sql] object SparkSqlSerializer2 { out.writeShort(row.getShort(i)) } - case IntegerType => + case IntegerType | DateType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -245,7 +245,7 @@ private[sql] object SparkSqlSerializer2 { out.writeInt(row.getInt(i)) } - case LongType => + case LongType | TimestampType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -269,55 +269,39 @@ private[sql] object SparkSqlSerializer2 { out.writeDouble(row.getDouble(i)) } - case decimal: DecimalType => + case StringType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.apply(i).asInstanceOf[Decimal] - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + val bytes = row.getAs[UTF8String](i).getBytes out.writeInt(bytes.length) out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) } - case DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getAs[Int](i)) - } - - case TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeLong(row.getAs[Long](i)) - } - - case StringType => + case BinaryType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[UTF8String](i).getBytes + val bytes = row.getAs[Array[Byte]](i) out.writeInt(bytes.length) out.write(bytes) } - case BinaryType => + case decimal: DecimalType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[Array[Byte]](i) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray out.writeInt(bytes.length) out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) } } i += 1 @@ -364,14 +348,14 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setShort(i, in.readShort()) } - case IntegerType => + case IntegerType | DateType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { mutableRow.setInt(i, in.readInt()) } - case LongType => + case LongType | TimestampType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { @@ -392,53 +376,39 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setDouble(i, in.readDouble()) } - case decimal: DecimalType => + case StringType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) - } - - case DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readInt()) - } - - case TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readLong()) + mutableRow.update(i, UTF8String.fromBytes(bytes)) } - case StringType => + case BinaryType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, UTF8String.fromBytes(bytes)) + mutableRow.update(i, bytes) } - case BinaryType => + case decimal: DecimalType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { + // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) } } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 0d96a1e8070b1..df2a96dfeb619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,19 +198,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - case StringType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(value.asInstanceOf[Int]) + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) + case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) case LongType => writer.addLong(value.asInstanceOf[Long]) case TimestampType => writeTimestamp(value.asInstanceOf[Long]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case DateType => writer.addInteger(value.asInstanceOf[Int]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) + case StringType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) + case BinaryType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -353,19 +352,18 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { record: InternalRow, index: Int): Unit = { ctype match { + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case IntegerType | DateType => writer.addInteger(record.getInt(index)) + case LongType => writer.addLong(record.getLong(index)) + case TimestampType => writeTimestamp(record.getLong(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) case StringType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(record.getInt(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case LongType => writer.addLong(record.getLong(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case DateType => writer.addInteger(record.getInt(index)) - case TimestampType => writeTimestamp(record.getLong(index)) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 4d5199a140344..e748bd7857bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.types._ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | StringType | BinaryType => true - case _: DataType => false + case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true + case _ => false } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 1f37455dd0bc4..9bd7b221e93f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -22,19 +22,20 @@ import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, + InternalRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, InternalRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 6daddfb2c4804..4d46a657056e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -36,9 +36,9 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8, - BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, + LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -60,27 +60,24 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(DATE, Int.MaxValue, 4) checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(TIMESTAMP, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - checkActualSize(BOOLEAN, true, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, 0L, 8) - - val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) + checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } - testNativeColumnType[BooleanType.type]( - BOOLEAN, + testNativeColumnType(BOOLEAN)( (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -88,18 +85,23 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { buffer.get() == 1 }) - testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + testNativeColumnType(BYTE)(_.put(_), _.get) + + testNativeColumnType(SHORT)(_.putShort(_), _.getShort) + + testNativeColumnType(INT)(_.putInt(_), _.getInt) + + testNativeColumnType(DATE)(_.putInt(_), _.getInt) - testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + testNativeColumnType(LONG)(_.putLong(_), _.getLong) - testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - testNativeColumnType[DecimalType]( - FIXED_DECIMAL(15, 10), + testNativeColumnType(FIXED_DECIMAL(15, 10))( (buffer: ByteBuffer, decimal: Decimal) => { buffer.putLong(decimal.toUnscaledLong) }, @@ -107,10 +109,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { Decimal(buffer.getLong(), 15, 10) }) - testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) - testNativeColumnType[StringType.type]( - STRING, + testNativeColumnType(STRING)( (buffer: ByteBuffer, string: UTF8String) => { val bytes = string.getBytes buffer.putInt(bytes.length) @@ -197,8 +197,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T], - putter: (ByteBuffer, T#InternalType) => Unit, + columnType: NativeColumnType[T]) + (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { testColumnType[T, T#InternalType](columnType, putter, getter) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 7c86eae3f77fd..d9861339739c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -39,18 +39,18 @@ object ColumnarTestUtils { } (columnType match { + case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() + case DATE => Random.nextInt() case LONG => Random.nextLong() + case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) - case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) - case DATE => Random.nextInt() - case TIMESTAMP => Random.nextLong() + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 2a6e0c376551a..9eaa769846088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -42,9 +42,9 @@ class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index cb4e9f1eb7f46..17e9ae464bcc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -38,9 +38,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnBuilder(_) } From 3664ee25f0a67de5ba76e9487a55a55216ae589f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Jun 2015 11:53:17 -0700 Subject: [PATCH 002/274] [SPARK-8066, SPARK-8067] [hive] Add support for Hive 1.0, 1.1 and 1.2. Allow HiveContext to connect to metastores of those versions; some new shims had to be added to account for changing internal APIs. A new test was added to exercise the "reset()" path which now also requires a shim; and the test code was changed to use a directory under the build's target to store ivy dependencies. Without that, at least I consistently run into issues with Ivy messing up (or being confused) by my existing caches. Author: Marcelo Vanzin Closes #7026 from vanzin/SPARK-8067 and squashes the following commits: 3e2e67b [Marcelo Vanzin] [SPARK-8066, SPARK-8067] [hive] Add support for Hive 1.0, 1.1 and 1.2. --- .../spark/sql/hive/client/ClientWrapper.scala | 5 +- .../spark/sql/hive/client/HiveShim.scala | 70 ++++++++++++++++++- .../hive/client/IsolatedClientLoader.scala | 13 ++-- .../spark/sql/hive/client/package.scala | 33 +++++++-- .../spark/sql/hive/client/VersionsSuite.scala | 25 +++++-- 5 files changed, 131 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 2f771d76793e5..4c708cec572ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -97,6 +97,9 @@ private[hive] class ClientWrapper( case hive.v12 => new Shim_v0_12() case hive.v13 => new Shim_v0_13() case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() } // Create an internal session state for this ClientWrapper. @@ -456,7 +459,7 @@ private[hive] class ClientWrapper( logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).foreach { index => - client.dropIndex("default", t, index.getIndexName, true) + shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index e7c1779f80ce6..1fa9d278e2a57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import java.lang.{Boolean => JBoolean, Integer => JInteger} +import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{Method, Modifier} import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} @@ -94,6 +94,8 @@ private[client] sealed abstract class Shim { holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit + def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit + protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = { val method = findMethod(klass, name, args: _*) require(Modifier.isStatic(method.getModifiers()), @@ -166,6 +168,14 @@ private[client] class Shim_v0_12 extends Shim { JInteger.TYPE, JBoolean.TYPE, JBoolean.TYPE) + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE) override def setCurrentSessionState(state: SessionState): Unit = { // Starting from Hive 0.13, setCurrentSessionState will internally override @@ -234,6 +244,10 @@ private[client] class Shim_v0_12 extends Shim { numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) } + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + } + } private[client] class Shim_v0_13 extends Shim_v0_12 { @@ -379,3 +393,57 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { TimeUnit.MILLISECONDS).asInstanceOf[Long] } } + +private[client] class Shim_v1_0 extends Shim_v0_14 { + +} + +private[client] class Shim_v1_1 extends Shim_v1_0 { + + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + } + +} + +private[client] class Shim_v1_2 extends Shim_v1_1 { + + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + 0: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 0934ad5034671..3d609a66f3664 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -41,9 +41,11 @@ private[hive] object IsolatedClientLoader { */ def forVersion( version: String, - config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized { + config: Map[String, String] = Map.empty, + ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion)) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, + downloadVersion(resolvedVersion, ivyPath)) new IsolatedClientLoader(hiveVersion(version), files, config) } @@ -51,9 +53,12 @@ private[hive] object IsolatedClientLoader { case "12" | "0.12" | "0.12.0" => hive.v12 case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 case "14" | "0.14" | "0.14.0" => hive.v14 + case "1.0" | "1.0.0" => hive.v1_0 + case "1.1" | "1.1.0" => hive.v1_1 + case "1.2" | "1.2.0" => hive.v1_2 } - private def downloadVersion(version: HiveVersion): Seq[URL] = { + private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ @@ -64,7 +69,7 @@ private[hive] object IsolatedClientLoader { SparkSubmitUtils.resolveMavenCoordinates( hiveArtifacts.mkString(","), Some("http://www.datanucleus.org/downloads/maven2"), - None, + ivyPath, exclusions = version.exclusions) } val allFiles = classpath.split(",").map(new File(_)).toSet diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 27a3d8f5896cc..b48082fe4b363 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -32,13 +32,36 @@ package object client { // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT which does not exist in // maven central anymore, so override those with a version that exists. // - // org.pentaho:pentaho-aggdesigner-algorithm is also nowhere to be found, so exclude - // it explicitly. If it's needed by the metastore client, users will have to dig it - // out of somewhere and use configuration to point Spark at the correct jars. + // The other excluded dependencies are also nowhere to be found, so exclude them explicitly. If + // they're needed by the metastore client, users will have to dig them out of somewhere and use + // configuration to point Spark at the correct jars. case object v14 extends HiveVersion("0.14.0", - Seq("org.apache.calcite:calcite-core:1.3.0-incubating", + extraDeps = Seq("org.apache.calcite:calcite-core:1.3.0-incubating", "org.apache.calcite:calcite-avatica:1.3.0-incubating"), - Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + exclusions = Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v1_0 extends HiveVersion("1.0.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + // The curator dependency was added to the exclusions here because it seems to confuse the ivy + // library. org.apache.curator:curator is a pom dependency but ivy tries to find the jar for it, + // and fails. + case object v1_1 extends HiveVersion("1.1.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + case object v1_2 extends HiveVersion("1.2.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) } // scalastyle:on diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9a571650b6e25..d52e162acbd04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.client +import java.io.File + import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils @@ -28,6 +30,12 @@ import org.apache.spark.util.Utils * is not fully tested. */ class VersionsSuite extends SparkFunSuite with Logging { + + // Do not use a temp path here to speed up subsequent executions of the unit test during + // development. + private val ivyPath = Some( + new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() @@ -38,7 +46,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client + val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -67,19 +75,21 @@ class VersionsSuite extends SparkFunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client } + val badClient = quietly { + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("12", "13", "14") + private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") private var client: ClientInterface = null versions.foreach { version => test(s"$version: create client") { client = null - client = IsolatedClientLoader.forVersion(version, buildConf()).client + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } test(s"$version: createDatabase") { @@ -170,5 +180,12 @@ class VersionsSuite extends SparkFunSuite with Logging { false, false) } + + test(s"$version: create index and reset") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + client.reset() + } } } From a5c2961caaafd751f11bdd406bb6885443d7572e Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 29 Jun 2015 11:57:19 -0700 Subject: [PATCH 003/274] [SPARK-8235] [SQL] misc function sha / sha1 Jira: https://issues.apache.org/jira/browse/SPARK-8235 I added the support for sha1. If I understood rxin correctly, sha and sha1 should execute the same algorithm, shouldn't they? Please take a close look on the Python part. This is adopted from #6934 Author: Tarek Auel Author: Tarek Auel Closes #6963 from tarekauel/SPARK-8235 and squashes the following commits: f064563 [Tarek Auel] change to shaHex 7ce3cdc [Tarek Auel] rely on automatic cast a1251d6 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-8235 68eb043 [Tarek Auel] added docstring be5aff1 [Tarek Auel] improved error message 7336c96 [Tarek Auel] added type check cf23a80 [Tarek Auel] simplified example ebf75ef [Tarek Auel] [SPARK-8301] updated the python documentation. Removed sha in python and scala 6d6ff0d [Tarek Auel] [SPARK-8233] added docstring ea191a9 [Tarek Auel] [SPARK-8233] fixed signatureof python function. Added expected type to misc e3fd7c3 [Tarek Auel] SPARK[8235] added sha to the list of __all__ e5dad4e [Tarek Auel] SPARK[8235] sha / sha1 --- python/pyspark/sql/functions.py | 14 +++++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 ++ .../spark/sql/catalyst/expressions/misc.scala | 30 ++++++++++++++++++- .../expressions/MiscFunctionsSuite.scala | 8 +++++ .../org/apache/spark/sql/functions.scala | 16 ++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 12 ++++++++ 6 files changed, 81 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7d3d0361610b7..45ecd826bd3bd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -42,6 +42,7 @@ 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha1', 'sha2', 'sparkPartitionId', 'struct', @@ -382,6 +383,19 @@ def sha2(col, numBits): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 457948a800a17..b24064d061533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -136,6 +136,8 @@ object FunctionRegistry { // misc functions expression[Md5]("md5"), expression[Sha2]("sha2"), + expression[Sha1]("sha1"), + expression[Sha1]("sha"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e80706fc65aff..9a39165a1ff05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,8 +21,9 @@ import java.security.MessageDigest import java.security.NoSuchAlgorithmException import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -140,3 +141,30 @@ case class Sha2(left: Expression, right: Expression) """ } } + +/** + * A function that calculates a sha1 hash value and returns it as a hex string + * For input of type [[BinaryType]] or [[StringType]] + */ +case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + ) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 38482c54c61db..36e636b5da6b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -31,6 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(null, BinaryType)), null) } + test("sha1") { + checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + "5d211bad8f4ee70e16c7d343a838fc344a1ed961") + checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) + checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + } + test("sha2") { checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 355ce0e3423cf..ef92801548a13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1414,6 +1414,22 @@ object functions { */ def md5(columnName: String): Column = md5(Column(columnName)) + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(columnName: String): Column = sha1(Column(columnName)) + /** * Calculates the SHA-2 family of hash functions and returns the value as a hex string. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8baed57a7f129..abfd47c811ed9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -144,6 +144,18 @@ class DataFrameFunctionsSuite extends QueryTest { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1("b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + test("misc sha2 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( From 492dca3a73e70705b5d5639e8fe4640b80e78d31 Mon Sep 17 00:00:00 2001 From: Vladimir Vladimirov Date: Mon, 29 Jun 2015 12:03:41 -0700 Subject: [PATCH 004/274] [SPARK-8528] Expose SparkContext.applicationId in PySpark Use case - we want to log applicationId (YARN in hour case) to request help with troubleshooting from the DevOps Author: Vladimir Vladimirov Closes #6936 from smartkiwi/master and squashes the following commits: 870338b [Vladimir Vladimirov] this would make doctest to run in python3 0eae619 [Vladimir Vladimirov] Scala doesn't use u'...' for unicode literals 14d77a8 [Vladimir Vladimirov] stop using ELLIPSIS b4ebfc5 [Vladimir Vladimirov] addressed PR feedback - updated docstring 223a32f [Vladimir Vladimirov] fixed test - applicationId is property that returns the string 3221f5a [Vladimir Vladimirov] [SPARK-8528] added documentation for Scala 2cff090 [Vladimir Vladimirov] [SPARK-8528] add applicationId property for SparkContext object in pyspark --- .../scala/org/apache/spark/SparkContext.scala | 8 ++++++++ python/pyspark/context.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c7a7436462083..b3c3bf3746e18 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 90b2fffbb9c7c..d7466729b8f36 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -291,6 +291,21 @@ def version(self): """ return self._jsc.version() + @property + @ignore_unicode_prefix + def applicationId(self): + """ + A unique identifier for the Spark application. + Its format depends on the scheduler implementation. + (i.e. + in case of local spark app something like 'local-1433865536131' + in case of YARN something like 'application_1433865536131_34483' + ) + >>> sc.applicationId # doctest: +ELLIPSIS + u'local-...' + """ + return self._jsc.sc().applicationId() + @property def startTime(self): """Return the epoch time when the Spark Context was started.""" From 94e040d05996111b2b448bcdee1cda184c6d039b Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 29 Jun 2015 12:16:12 -0700 Subject: [PATCH 005/274] [SQL][DOCS] Remove wrong example from DataFrame.scala In DataFrame.scala, there are examples like as follows. ``` * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) * peopleDf($"age" > 15) ``` But, I think the last example doesn't work. Author: Kousuke Saruta Closes #6977 from sarutak/fix-dataframe-example and squashes the following commits: 46efbd7 [Kousuke Saruta] Removed wrong example --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d75d88307562e..986e59133919f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -682,7 +682,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 @@ -707,7 +706,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 From 637b4eedad84dcff1769454137a64ac70c7f2397 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Mon, 29 Jun 2015 12:25:16 -0700 Subject: [PATCH 006/274] [SPARK-8214] [SQL] Add function hex cc chenghao-intel adrian-wang Author: zhichao.li Closes #6976 from zhichao-li/hex and squashes the following commits: e218d1b [zhichao.li] turn off scalastyle for non-ascii de3f5ea [zhichao.li] non-ascii char cf9c936 [zhichao.li] give separated buffer for each hex method 967ec90 [zhichao.li] Make 'value' as a feild of Hex 3b2fa13 [zhichao.li] tiny fix a647641 [zhichao.li] remove duplicate null check 7cab020 [zhichao.li] tiny refactoring 35ecfe5 [zhichao.li] add function hex --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 86 ++++++++++++++++++- .../expressions/MathFunctionsSuite.scala | 14 ++- .../org/apache/spark/sql/functions.scala | 16 ++++ .../spark/sql/MathExpressionsSuite.scala | 13 +++ 5 files changed, 125 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b24064d061533..b17457d3094c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -113,6 +113,7 @@ object FunctionRegistry { expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Hypot]("hypot"), + expression[Hex]("hex"), expression[Logarithm]("log"), expression[Log]("ln"), expression[Log10]("log10"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 5694afc61be05..4b57ddd9c5768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} +import java.util.Arrays +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -273,9 +275,6 @@ case class Atan2(left: Expression, right: Expression) } } -case class Hypot(left: Expression, right: Expression) - extends BinaryMathExpression(math.hypot, "HYPOT") - case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -287,6 +286,85 @@ case class Pow(left: Expression, right: Expression) } } +/** + * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. + * Otherwise if the number is a STRING, + * it converts each character into its hexadecimal representation and returns the resulting STRING. + * Negative numbers would be treated as two's complement. + */ +case class Hex(child: Expression) + extends UnaryExpression with Serializable { + + override def dataType: DataType = StringType + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[StringType] + || child.dataType.isInstanceOf[IntegerType] + || child.dataType.isInstanceOf[LongType] + || child.dataType.isInstanceOf[BinaryType] + || child.dataType == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type") + } + } + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + child.dataType match { + case LongType => hex(num.asInstanceOf[Long]) + case IntegerType => hex(num.asInstanceOf[Integer].toLong) + case BinaryType => hex(num.asInstanceOf[Array[Byte]]) + case StringType => hex(num.asInstanceOf[UTF8String]) + } + } + } + + /** + * Converts every character in s to two hex digits. + */ + private def hex(str: UTF8String): UTF8String = { + hex(str.getBytes) + } + + private def hex(bytes: Array[Byte]): UTF8String = { + doHex(bytes, bytes.length) + } + + private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + val value = new Array[Byte](length * 2) + var i = 0 + while(i < length) { + value(i * 2) = Character.toUpperCase(Character.forDigit( + (bytes(i) & 0xF0) >>> 4, 16)).toByte + value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( + bytes(i) & 0x0F, 16)).toByte + i += 1 + } + UTF8String.fromBytes(value) + } + + private def hex(num: Long): UTF8String = { + // Extract the hex digits of num into value[] from right to left + val value = new Array[Byte](16) + var numBuf = num + var len = 0 + do { + len += 1 + value(value.length - len) = Character.toUpperCase(Character + .forDigit((numBuf & 0xF).toInt, 16)).toByte + numBuf >>>= 4 + } while (numBuf != 0) + UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) + } +} + +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") + case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 0d1d5ebdff2d5..b932d4ab850c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.{DataType, DoubleType, LongType} @@ -226,6 +225,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) } + test("hex") { + checkEvaluation(Hex(Literal(28)), "1C") + checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") + checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84") + // scalastyle:on + } + test("hypot") { testBinary(Hypot, math.hypot) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ef92801548a13..5422e066afcb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1046,6 +1046,22 @@ object functions { */ def floor(columnName: String): Column = floor(Column(columnName)) + /** + * Computes hex value of the given column + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = Hex(column.expr) + + /** + * Computes hex value of the given input + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(colName: String): Column = hex(Column(colName)) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 2768d7dfc8030..d6331aa4ff09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -212,6 +212,19 @@ class MathExpressionsSuite extends QueryTest { ) } + test("hex") { + val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") + checkAnswer(data.select(hex('a)), Seq(Row("1C"))) + checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) + checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) + checkAnswer(data.selectExpr("hex(d)"), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) + } + test("hypot") { testTwoToOneMathFunction(hypot, hypot, math.hypot) } From c6ba2ea341ad23de265d870669b25e6a41f461e5 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 29 Jun 2015 12:46:33 -0700 Subject: [PATCH 007/274] [SPARK-7862] [SQL] Disable the error message redirect to stderr This is a follow up of #6404, the ScriptTransformation prints the error msg into stderr directly, probably be a disaster for application log. Author: Cheng Hao Closes #6882 from chenghao-intel/verbose and squashes the following commits: bfedd77 [Cheng Hao] revert the write 76ff46b [Cheng Hao] update the CircularBuffer 692b19e [Cheng Hao] check the process exitValue for ScriptTransform 47e0970 [Cheng Hao] Use the RedirectThread instead 1de771d [Cheng Hao] naming the threads in ScriptTransformation 8536e81 [Cheng Hao] disable the error message redirection for stderr --- .../scala/org/apache/spark/util/Utils.scala | 33 ++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 8 +++ .../spark/sql/hive/client/ClientWrapper.scala | 29 ++--------- .../hive/execution/ScriptTransformation.scala | 51 ++++++++++++------- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 5 files changed, 77 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 19157af5b6f4d..a7fc749a2b0c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2333,3 +2333,36 @@ private[spark] class RedirectThread( } } } + +/** + * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it + * in a circular buffer. The current contents of the buffer can be accessed using + * the toString method. + */ +private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](sizeInBytes) + + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.length + } + + override def toString: String = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next() else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while (line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a61ea3918f46a..baa4c661cc21e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -673,4 +673,12 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, parentDir)) assert(!Utils.isInDirectory(nullFile, childFile3)) } + + test("circular buffer") { + val buffer = new CircularBuffer(25) + val stream = new java.io.PrintStream(buffer, true, "UTF-8") + + stream.println("test circular test circular test circular test circular test circular") + assert(buffer.toString === "t circular test circular\n") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 4c708cec572ae..cbd2bf6b5eede 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -22,6 +22,8 @@ import java.net.URI import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} import javax.annotation.concurrent.GuardedBy +import org.apache.spark.util.CircularBuffer + import scala.collection.JavaConversions._ import scala.language.reflectiveCalls @@ -66,32 +68,7 @@ private[hive] class ClientWrapper( with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - private val outputBuffer = new java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](10240) - def write(i: Int): Unit = { - buffer(pos) = i - pos = (pos + 1) % buffer.size - } - - override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while(line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() - } - stringBuilder.toString() - } - } + private val outputBuffer = new CircularBuffer() private val shim = version match { case hive.v12 => new Shim_v0_12() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 611888055d6cf..b967e191c5855 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} /** * Transforms the input by forking and running the specified script. @@ -59,15 +59,13 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) - // redirectError(Redirect.INHERIT) would consume the error output from buffer and - // then print it to stderr (inherit the target from the current Scala process). - // If without this there would be 2 issues: + // We need to start threads connected to the process pipeline: // 1) The error msg generated by the script process would be hidden. // 2) If the error msg is too big to chock up the buffer, the input logic would be hung - builder.redirectError(Redirect.INHERIT) val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream val reader = new BufferedReader(new InputStreamReader(inputStream)) val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) @@ -152,29 +150,43 @@ case class ScriptTransformation( val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) + // TODO make the 2048 configurable? + val stderrBuffer = new CircularBuffer(2048) + // Consume the error stream from the pipeline, otherwise it will be blocked if + // the pipeline is full. + new RedirectThread(errorStream, // input stream from the pipeline + stderrBuffer, // output to a circular buffer + "Thread-ScriptTransformation-STDERR-Consumer").start() + // Put the write(output to the pipeline) into a single thread // and keep the collector as remain in the main thread. // otherwise it will causes deadlock if the data size greater than // the pipeline / buffer capacity. new Thread(new Runnable() { override def run(): Unit = { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + Utils.tryWithSafeFinally { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + } { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer } } - outputStream.close() } - }).start() + }, "Thread-ScriptTransformation-Feed").start() iterator } @@ -278,3 +290,4 @@ case class HiveScriptIOSchema ( } } } + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f0aad8dbbe64d..9f7e58f890241 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -653,7 +653,7 @@ class SQLQuerySuite extends QueryTest { .queryExecution.toRdd.count()) } - ignore("test script transform for stderr") { + test("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === From be7ef067620408859144e0244b0f1b8eb56faa86 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 13:15:04 -0700 Subject: [PATCH 008/274] [SPARK-8681] fixed wrong ordering of columns in crosstab I specifically randomized the test. What crosstab does is equivalent to a countByKey, therefore if this test fails again for any reason, we will know that we hit a corner case or something. cc rxin marmbrus Author: Burak Yavuz Closes #7060 from brkyvz/crosstab-fixes and squashes the following commits: 0a65234 [Burak Yavuz] addressed comments v1 d96da7e [Burak Yavuz] fixed wrong ordering of columns in crosstab --- .../sql/execution/stat/StatFunctions.scala | 8 ++++-- .../apache/spark/sql/DataFrameStatSuite.scala | 28 ++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 042e2c9cbb22e..b624ef7e8fa1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -111,7 +111,7 @@ private[sql] object StatFunctions extends Logging { "the pairs. Please try reducing the amount of distinct items in your columns.") } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") @@ -120,14 +120,16 @@ private[sql] object StatFunctions extends Logging { rows.foreach { (row: Row) => // row.get(0) is column 1 // row.get(1) is column 2 - // row.get(3) is the frequency + // row.get(2) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts countsRow.update(0, UTF8String.fromString(col1Item.toString)) countsRow }.toSeq - val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType)) val schema = StructType(StructField(tableName, StringType) +: headerNames) new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899dad72..64ec1a70c47e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Random + import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite @@ -65,22 +67,22 @@ class DataFrameStatSuite extends SparkFunSuite { } test("crosstab") { - val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b") + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "0") - assert(columnNames(2) === "1") - val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) - assert(rows(0).get(0).toString === "0") - assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === 0L) - assert(rows(1).get(0).toString === "1") - assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === 0L) - assert(rows(2).get(0).toString === "2") - assert(rows(2).getLong(1) === 2L) - assert(rows(2).getLong(2) === 1L) + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 to 9) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } } test("Frequent Items") { From afae9766f28d2e58297405c39862d20a04267b62 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Jun 2015 13:20:55 -0700 Subject: [PATCH 009/274] [SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame Avoid the unnecessary jobs when infer schema from list. cc yhuai mengxr Author: Davies Liu Closes #6606 from davies/improve_create and squashes the following commits: a5928bf [Davies Liu] Update MimaExcludes.scala 62da911 [Davies Liu] fix mima bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 8d9292d [Davies Liu] Update context.py eb24531 [Davies Liu] Update context.py c969997 [Davies Liu] bug fix d5a8ab0 [Davies Liu] fix tests 8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 6ea5925 [Davies Liu] address comments 6ceaeff [Davies Liu] avoid spark jobs in createDataFrame --- python/pyspark/sql/context.py | 64 +++++++++++++++++++++++++---------- python/pyspark/sql/types.py | 48 +++++++++++++++----------- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index dc239226e6d3c..4dda3b430cfbf 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -203,7 +203,37 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = _infer_schema(first) + if _has_nulltype(schema): + for r in data: + schema = _merge_type(schema, _infer_schema(r)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined after inferring") + return schema + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " @@ -322,6 +352,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): + if not isinstance(data, list): + data = list(data) try: # data could be list, tuple, generator ... rdd = self._sc.parallelize(data) @@ -330,28 +362,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): else: rdd = data - if schema is None: - schema = self._inferSchema(rdd, samplingRatio) + if schema is None or isinstance(schema, (list, tuple)): + if isinstance(data, RDD): + struct = self._inferSchema(rdd, samplingRatio) + else: + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct converter = _create_converter(schema) rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - first = rdd.first() - if not isinstance(first, (list, tuple)): - raise TypeError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) - row_cls = Row(*schema) - schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) + elif isinstance(schema, StructType): + # take the first few rows to verify schema rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) - for row in rows: - _verify_type(row, schema) + else: + raise TypeError("schema should be StructType or list or None") # convert python objects to sql data converter = _python_to_sql_converter(schema) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 23d9adb0daea1..932686e5e4b01 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -635,7 +635,7 @@ def _need_python_to_sql_conversion(dataType): >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), ... StructField("values", ArrayType(DoubleType(), False), False)]) >>> _need_python_to_sql_conversion(schema0) - False + True >>> _need_python_to_sql_conversion(ExamplePointUDT()) True >>> schema1 = ArrayType(ExamplePointUDT(), False) @@ -647,7 +647,8 @@ def _need_python_to_sql_conversion(dataType): True """ if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + # convert namedtuple or Row into tuple + return True elif isinstance(dataType, ArrayType): return _need_python_to_sql_conversion(dataType.elementType) elif isinstance(dataType, MapType): @@ -688,21 +689,25 @@ def _python_to_sql_converter(dataType): if isinstance(dataType, StructType): names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) + if any(_need_python_to_sql_conversion(t) for t in types): + converters = [_python_to_sql_converter(t) for t in types] + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): + return tuple(c(v) for c, v in zip(converters, obj)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + elif obj is not None: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + else: + def converter(obj): + if isinstance(obj, dict): + return tuple(obj.get(n) for n in names) else: - return tuple(c(v) for c, v in zip(converters, obj)) - elif obj is not None: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return tuple(obj) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) @@ -1027,10 +1032,13 @@ def _verify_type(obj, dataType): _type = type(dataType) assert _type in _acceptable_types, "unknown datatype: %s" % dataType - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) + if _type is StructType: + if not isinstance(obj, (tuple, list)): + raise TypeError("StructType can not accept object in type %s" % type(obj)) + else: + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): for i in obj: From 27ef85451cd237caa7016baa69957a35ab365aa8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 14:07:55 -0700 Subject: [PATCH 010/274] [SPARK-8709] Exclude hadoop-client's mockito-all dependency This patch excludes `hadoop-client`'s dependency on `mockito-all`. As of #7061, Spark depends on `mockito-core` instead of `mockito-all`, so the dependency from Hadoop was leading to test compilation failures for some of the Hadoop 2 SBT builds. Author: Josh Rosen Closes #7090 from JoshRosen/SPARK-8709 and squashes the following commits: e190122 [Josh Rosen] [SPARK-8709] Exclude hadoop-client's mockito-all dependency. --- LICENSE | 2 +- core/pom.xml | 10 ---------- launcher/pom.xml | 6 ------ pom.xml | 8 ++++++++ 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/LICENSE b/LICENSE index 8672be55eca3e..f9e412cade345 100644 --- a/LICENSE +++ b/LICENSE @@ -948,6 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/core/pom.xml b/core/pom.xml index 565437c4861a4..aee0d92620606 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -69,16 +69,6 @@ org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark diff --git a/launcher/pom.xml b/launcher/pom.xml index a853e67f5cf78..2fd768d8119c4 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/pom.xml b/pom.xml index 4c18bd5e42c87..94dd512cfb618 100644 --- a/pom.xml +++ b/pom.xml @@ -747,6 +747,10 @@ asm asm + + org.codehaus.jackson + jackson-mapper-asl + org.ow2.asm asm @@ -759,6 +763,10 @@ commons-logging commons-logging + + org.mockito + mockito-all + org.mortbay.jetty servlet-api-2.5 From f6fc254ec4ce5f103d45da6d007b4066ce751236 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Mon, 29 Jun 2015 14:15:15 -0700 Subject: [PATCH 011/274] [SPARK-8056][SQL] Design an easier way to construct schema for both Scala and Python I've added functionality to create new StructType similar to how we add parameters to a new SparkContext. I've also added tests for this type of creation. Author: Ilya Ganelin Closes #6686 from ilganeli/SPARK-8056B and squashes the following commits: 27c1de1 [Ilya Ganelin] Rename 467d836 [Ilya Ganelin] Removed from_string in favor of _parse_Datatype_json_value 5fef5a4 [Ilya Ganelin] Updates for type parsing 4085489 [Ilya Ganelin] Style errors 3670cf5 [Ilya Ganelin] added string to DataType conversion 8109e00 [Ilya Ganelin] Fixed error in tests 41ab686 [Ilya Ganelin] Fixed style errors e7ba7e0 [Ilya Ganelin] Moved some python tests to tests.py. Added cleaner handling of null data type and added test for correctness of input format 15868fa [Ilya Ganelin] Fixed python errors b79b992 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-8056B a3369fc [Ilya Ganelin] Fixing space errors e240040 [Ilya Ganelin] Style bab7823 [Ilya Ganelin] Constructor error 73d4677 [Ilya Ganelin] Style 4ed00d9 [Ilya Ganelin] Fixed default arg 67df57a [Ilya Ganelin] Removed Foo 04cbf0c [Ilya Ganelin] Added comments for single object 0484d7a [Ilya Ganelin] Restored second method 6aeb740 [Ilya Ganelin] Style 689e54d [Ilya Ganelin] Style f497e9e [Ilya Ganelin] Got rid of old code e3c7a88 [Ilya Ganelin] Fixed doctest failure a62ccde [Ilya Ganelin] Style 966ac06 [Ilya Ganelin] style checks dabb7e6 [Ilya Ganelin] Added Python tests a3f4152 [Ilya Ganelin] added python bindings and better comments e6e536c [Ilya Ganelin] Added extra space 7529a2e [Ilya Ganelin] Fixed formatting d388f86 [Ilya Ganelin] Fixed small bug c4e3bf5 [Ilya Ganelin] Reverted to using parse. Updated parse to support long d7634b6 [Ilya Ganelin] Reverted to fromString to properly support types 22c39d5 [Ilya Ganelin] replaced FromString with DataTypeParser.parse. Replaced empty constructor initializing a null to have it instead create a new array to allow appends to it. faca398 [Ilya Ganelin] [SPARK-8056] Replaced default argument usage. Updated usage and code for DataType.fromString 1acf76e [Ilya Ganelin] Scala style e31c674 [Ilya Ganelin] Fixed bug in test 8dc0795 [Ilya Ganelin] Added tests for creation of StructType object with new methods fdf7e9f [Ilya Ganelin] [SPARK-8056] Created add methods to facilitate building new StructType objects. --- python/pyspark/sql/tests.py | 29 +++++ python/pyspark/sql/types.py | 52 ++++++++- .../spark/sql/types/DataTypeParser.scala | 2 +- .../apache/spark/sql/types/StructType.scala | 104 +++++++++++++++++- .../spark/sql/types/DataTypeSuite.scala | 31 ++++++ 5 files changed, 212 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ffee43a94baba..34f397d0ffef0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -516,6 +516,35 @@ def test_between_function(self): self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()) + def test_struct_type(self): + from pyspark.sql.types import StructType, StringType, StructField + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + try: + struct1 = StructType().add("name") + self.assertEqual(1, 0) + except ValueError: + self.assertEqual(1, 1) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 932686e5e4b01..ae9344e6106a4 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -355,8 +355,7 @@ class StructType(DataType): This is the data type representing a :class:`Row`. """ - - def __init__(self, fields): + def __init__(self, fields=None): """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) @@ -368,8 +367,53 @@ def __init__(self, fields): >>> struct1 == struct2 False """ - assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" - self.fields = fields + if not fields: + self.fields = [] + else: + self.fields = fields + assert all(isinstance(f, StructField) for f in fields),\ + "fields should be a list of StructField" + + def add(self, field, data_type=None, nullable=True, metadata=None): + """ + Construct a StructType by adding new elements to it to define the schema. The method accepts + either: + a) A single parameter which is a StructField object. + b) Between 2 and 4 parameters as (name, data_type, nullable (optional), + metadata(optional). The data_type parameter may be either a String or a DataType object + + >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + >>> struct2 = StructType([StructField("f1", StringType(), True),\ + StructField("f2", StringType(), True, None)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add(StructField("f1", StringType(), True)) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add("f1", "string", True) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + + :param field: Either the name of the field or a StructField object + :param data_type: If present, the DataType of the StructField to create + :param nullable: Whether the field to add should be nullable (default True) + :param metadata: Any additional metadata (default None) + :return: a new updated StructType + """ + if isinstance(field, StructField): + self.fields.append(field) + else: + if isinstance(field, str) and data_type is None: + raise ValueError("Must specify DataType if passing name of struct_field to create.") + + if isinstance(data_type, str): + data_type_f = _parse_datatype_json_value(data_type) + else: + data_type_f = data_type + self.fields.append(StructField(field, data_type_f, nullable, metadata)) + return self def simpleString(self): return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 04f3379afb38d..6b43224feb1f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -44,7 +44,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | - "(?i)bigint".r ^^^ LongType | + "(?i)(?:bigint|long)".r ^^^ LongType | "(?i)binary".r ^^^ BinaryType | "(?i)boolean".r ^^^ BooleanType | fixedDecimalType | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 193c08a4d0df7..2db0a359e9db5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -94,7 +94,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ - protected def this() = this(null) + def this() = this(Array.empty[StructField]) /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -103,6 +103,108 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + /** + * Creates a new [[StructType]] by adding a new field. + * {{{ + * val struct = (new StructType) + * .add(StructField("a", IntegerType, true)) + * .add(StructField("b", LongType, false)) + * .add(StructField("c", StringType, true)) + *}}} + */ + def add(field: StructField): StructType = { + StructType(fields :+ field) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType) + * .add("b", LongType) + * .add("c", StringType) + */ + def add(name: String, dataType: DataType): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable = true, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType, true) + * .add("b", LongType, false) + * .add("c", StringType, true) + */ + def add(name: String, dataType: DataType, nullable: Boolean): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata. + * {{{ + * val struct = (new StructType) + * .add("a", IntegerType, true, Metadata.empty) + * .add("b", LongType, false, Metadata.empty) + * .add("c", StringType, true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, metadata)) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int") + * .add("b", "long") + * .add("c", "string") + * }}} + */ + def add(name: String, dataType: String): StructType = { + add(name, DataTypeParser.parse(dataType), nullable = true, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int", true) + * .add("b", "long", false) + * .add("c", "string", true) + * }}} + */ + def add(name: String, dataType: String, nullable: Boolean): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata where the + * dataType is specified as a String. + * {{{ + * val struct = (new StructType) + * .add("a", "int", true, Metadata.empty) + * .add("b", "long", false, Metadata.empty) + * .add("c", "string", true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: String, + nullable: Boolean, + metadata: Metadata): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, metadata) + } + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 077c0ad70ac4f..14e7b4a9561b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -33,6 +33,37 @@ class DataTypeSuite extends SparkFunSuite { assert(MapType(StringType, IntegerType, true) === map) } + test("construct with add") { + val struct = (new StructType) + .add("a", IntegerType, true) + .add("b", LongType, false) + .add("c", StringType, true) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with add from StructField") { + // Test creation from StructField type + val struct = (new StructType) + .add(StructField("a", IntegerType, true)) + .add(StructField("b", LongType, false)) + .add(StructField("c", StringType, true)) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with String DataType") { + // Test creation with DataType as String + val struct = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + + assert(StructField("a", IntegerType, true) === struct("a")) + assert(StructField("b", LongType, false) === struct("b")) + assert(StructField("c", StringType, true) === struct("c")) + } + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: From ecd3aacf2805bb231cfb44bab079319cfe73c3f1 Mon Sep 17 00:00:00 2001 From: Ai He Date: Mon, 29 Jun 2015 14:36:26 -0700 Subject: [PATCH 012/274] [SPARK-7810] [PYSPARK] solve python rdd socket connection problem Method "_load_from_socket" in rdd.py cannot load data from jvm socket when ipv6 is used. The current method only works well with ipv4. New modification should work around both two protocols. Author: Ai He Author: AiHe Closes #6338 from AiHe/pyspark-networking-issue and squashes the following commits: d4fc9c4 [Ai He] handle code review 2 e75c5c8 [Ai He] handle code review 5644953 [AiHe] solve python rdd socket connection problem to jvm --- python/pyspark/rdd.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1b64be23a667e..cb20bc8b54027 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -121,10 +121,22 @@ def _parse_memory(s): def _load_from_socket(port, serializer): - sock = socket.socket() - sock.settimeout(3) + sock = None + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(3) + sock.connect(sa) + except socket.error: + sock = None + continue + break + if not sock: + raise Exception("could not open socket") try: - sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) for item in serializer.load_stream(rf): yield item From c8ae887ef02b8f7e2ad06841719fb12eacf1f7f9 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Mon, 29 Jun 2015 14:45:08 -0700 Subject: [PATCH 013/274] [SPARK-8660][ML] Convert JavaDoc style comments inLogisticRegressionSuite.scala to regular multiline comments, to make copy-pasting R commands easier Converted JavaDoc style comments in mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala to regular multiline comments, to make copy-pasting R commands easier. Author: Rosstin Closes #7096 from Rosstin/SPARK-8660 and squashes the following commits: 242aedd [Rosstin] SPARK-8660, changed comment style from JavaDoc style to normal multiline comment in order to make copypaste into R easier, in file classification/LogisticRegressionSuite.scala 2cd2985 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 21ac1e5 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- .../LogisticRegressionSuite.scala | 342 +++++++++--------- 1 file changed, 171 insertions(+), 171 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 5a6265ea992c6..bc6eeac1db5da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -36,19 +36,19 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) - /** - * Here is the instruction describing how to export the test data into CSV format - * so we can validate the training accuracy compared with R's glmnet package. - * - * import org.apache.spark.mllib.classification.LogisticRegressionSuite - * val nPoints = 10000 - * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) - * val xMean = Array(5.843, 3.057, 3.758, 1.199) - * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - * weights, xMean, xVariance, true, nPoints, 42), 1) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " - * + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. + + import org.apache.spark.mllib.classification.LogisticRegressionSuite + val nPoints = 10000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") */ binaryDataset = { val nPoints = 10000 @@ -211,22 +211,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(true) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 2.8366423 - * data.V2 -0.5895848 - * data.V3 0.8931147 - * data.V4 -0.3925051 - * data.V5 -0.7996864 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 */ val interceptR = 2.8366423 val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864) @@ -242,23 +242,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(false) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = - * coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.3534996 - * data.V3 1.2964482 - * data.V4 -0.3571741 - * data.V5 -0.7407946 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 */ val interceptR = 0.0 val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946) @@ -275,22 +275,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.05627428 - * data.V2 . - * data.V3 . - * data.V4 -0.04325749 - * data.V5 -0.02481551 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 */ val interceptR = -0.05627428 val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551) @@ -307,23 +307,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 . - * data.V3 . - * data.V4 -0.05189203 - * data.V5 -0.03891782 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 */ val interceptR = 0.0 val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782) @@ -340,22 +340,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.15021751 - * data.V2 -0.07251837 - * data.V3 0.10724191 - * data.V4 -0.04865309 - * data.V5 -0.10062872 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 */ val interceptR = 0.15021751 val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872) @@ -372,23 +372,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.06099165 - * data.V3 0.12857058 - * data.V4 -0.04708770 - * data.V5 -0.09799775 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 */ val interceptR = 0.0 val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775) @@ -405,22 +405,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.57734851 - * data.V2 -0.05310287 - * data.V3 . - * data.V4 -0.08849250 - * data.V5 -0.15458796 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.57734851 + data.V2 -0.05310287 + data.V3 . + data.V4 -0.08849250 + data.V5 -0.15458796 */ val interceptR = 0.57734851 val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796) @@ -437,23 +437,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.001005743 - * data.V3 0.072577857 - * data.V4 -0.081203769 - * data.V5 -0.142534158 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.001005743 + data.V3 0.072577857 + data.V4 -0.081203769 + data.V5 -0.142534158 */ val interceptR = 0.0 val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158) @@ -480,16 +480,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { classSummarizer1.merge(classSummarizer2) }).histogram - /** - * For binary logistic regression with strong L1 regularization, all the weights will be zeros. - * As a result, - * {{{ - * P(0) = 1 / (1 + \exp(b)), and - * P(1) = \exp(b) / (1 + \exp(b)) - * }}}, hence - * {{{ - * b = \log{P(1) / P(0)} = \log{count_1 / count_0} - * }}} + /* + For binary logistic regression with strong L1 regularization, all the weights will be zeros. + As a result, + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} */ val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) val weightsTheory = Array(0.0, 0.0, 0.0, 0.0) @@ -500,22 +500,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6) assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.2480643 - * data.V2 0.0000000 - * data.V3 . - * data.V4 . - * data.V5 . + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.2480643 + data.V2 0.0000000 + data.V3 . + data.V4 . + data.V5 . */ val interceptR = -0.248065 val weightsR = Array(0.0, 0.0, 0.0, 0.0) From 931da5c8ab271ff2ee04419c7e3c6b0012459694 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 29 Jun 2015 15:27:13 -0700 Subject: [PATCH 014/274] [SPARK-8478] [SQL] Harmonize UDF-related code to use uniformly UDF instead of Udf Follow-up of #6902 for being coherent between ```Udf``` and ```UDF``` Author: BenFradet Closes #6920 from BenFradet/SPARK-8478 and squashes the following commits: c500f29 [BenFradet] renamed a few variables in functions to use UDF 8ab0f2d [BenFradet] renamed idUdf to idUDF in SQLQuerySuite 98696c2 [BenFradet] renamed originalUdfs in TestHive to originalUDFs 7738f74 [BenFradet] modified HiveUDFSuite to use only UDF c52608d [BenFradet] renamed HiveUdfSuite to HiveUDFSuite e51b9ac [BenFradet] renamed ExtractPythonUdfs to ExtractPythonUDFs 8c756f1 [BenFradet] renamed Hive UDF related code 2a1ca76 [BenFradet] renamed pythonUdfs to pythonUDFs 261e6fb [BenFradet] renamed ScalaUdf to ScalaUDF --- .../{ScalaUdf.scala => ScalaUDF.scala} | 4 +- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../apache/spark/sql/UDFRegistration.scala | 96 +++++++++--------- .../spark/sql/UserDefinedFunction.scala | 4 +- .../{pythonUdfs.scala => pythonUDFs.scala} | 2 +- .../org/apache/spark/sql/functions.scala | 34 +++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../hive/{hiveUdfs.scala => hiveUDFs.scala} | 26 ++--- .../apache/spark/sql/hive/test/TestHive.scala | 4 +- .../files/{testUdf => testUDF}/part-00000 | Bin ...{HiveUdfSuite.scala => HiveUDFSuite.scala} | 24 ++--- 13 files changed, 104 insertions(+), 104 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{ScalaUdf.scala => ScalaUDF.scala} (99%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{pythonUdfs.scala => pythonUDFs.scala} (99%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{hiveUdfs.scala => hiveUDFs.scala} (96%) rename sql/hive/src/test/resources/data/files/{testUdf => testUDF}/part-00000 (100%) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/{HiveUdfSuite.scala => HiveUDFSuite.scala} (93%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala similarity index 99% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 55df72f102295..dbb4381d54c4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.DataType * User-defined function. * @param dataType Return type of function. */ -case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) +case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { override def nullable: Boolean = true @@ -957,6 +957,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) - // TODO(davies): make ScalaUdf work with codegen + // TODO(davies): make ScalaUDF work with codegen override def isThreadSafe: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8ed44ee141be5..fc14a77538ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -146,7 +146,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUdfs :: + ExtractPythonUDFs :: sources.PreInsertCastAndRename :: Nil @@ -257,7 +257,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * The following example registers a Scala closure as UDF: * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) * }}} * * The following example registers a UDF in Java: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3cc5c2441d8a5..03dc37aa73f0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -95,7 +95,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) }""") @@ -114,7 +114,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { | functionRegistry.registerFunction( | name, - | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), returnType, e)) + | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) |}""".stripMargin) } */ @@ -126,7 +126,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -138,7 +138,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -150,7 +150,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -162,7 +162,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -174,7 +174,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -186,7 +186,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -210,7 +210,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -222,7 +222,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -234,7 +234,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -246,7 +246,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -258,7 +258,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -270,7 +270,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -282,7 +282,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -294,7 +294,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -306,7 +306,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -318,7 +318,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -330,7 +330,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -342,7 +342,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -366,7 +366,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -378,7 +378,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -390,7 +390,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -405,7 +405,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF1[_, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) } /** @@ -415,7 +415,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) } /** @@ -425,7 +425,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) } /** @@ -435,7 +435,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -445,7 +445,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -455,7 +455,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -465,7 +465,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -475,7 +475,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -485,7 +485,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -495,7 +495,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -505,7 +505,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -515,7 +515,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -525,7 +525,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -535,7 +535,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -545,7 +545,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -555,7 +555,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -565,7 +565,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -575,7 +575,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -585,7 +585,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -595,7 +595,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -605,7 +605,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -615,7 +615,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index a02e202d2eebc..831eb7eb0fae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -23,7 +23,7 @@ import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { def apply(exprs: Column*): Column = { - Column(ScalaUdf(f, dataType, exprs.map(_.expr))) + Column(ScalaUDF(f, dataType, exprs.map(_.expr))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 036f5d253e385..9e1cff06c7eea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -69,7 +69,7 @@ private[spark] case class PythonUDF( * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5422e066afcb1..4d9a019058228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1509,7 +1509,7 @@ object functions { (0 to 10).map { x => val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires @@ -1521,7 +1521,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) + ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } } @@ -1659,7 +1659,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) + ScalaUDF(f, returnType, Seq()) } /** @@ -1672,7 +1672,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr)) } /** @@ -1685,7 +1685,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** @@ -1698,7 +1698,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** @@ -1711,7 +1711,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** @@ -1724,7 +1724,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** @@ -1737,7 +1737,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** @@ -1750,7 +1750,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** @@ -1763,7 +1763,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** @@ -1776,7 +1776,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** @@ -1789,7 +1789,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on @@ -1802,8 +1802,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUDF("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) * }}} * * @group udf_funcs @@ -1821,8 +1821,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) * }}} * * @group udf_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 22c54e43c1d16..82dc0e9ce5132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -140,9 +140,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUdf = udf(() => UUID.randomUUID().toString) + val idUDF = udf(() => UUID.randomUUID().toString) - val dfWithId = df.withColumn("id", idUdf()) + val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) val cached = dfWithId.cache() // Trigger the cache diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8021f915bb821..b91242af2d155 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.sources.DataSourceStrategy @@ -381,7 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7c4620952ba4b..2de7a99c122fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1638,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - (HiveGenericUdtf( + (HiveGenericUDTF( new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)), attributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala similarity index 96% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4986b1ea9d906..d7827d56ca8c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -59,16 +59,16 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -79,7 +79,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) throw new UnsupportedOperationException } -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = UDF @@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) override def get(): AnyRef = wrap(func(), oi) } -private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF @@ -413,7 +413,7 @@ private[hive] case class HiveWindowFunction( new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } -private[hive] case class HiveGenericUdaf( +private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -441,11 +441,11 @@ private[hive] case class HiveGenericUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( +private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -474,7 +474,7 @@ private[hive] case class HiveUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) } /** @@ -488,7 +488,7 @@ private[hive] case class HiveUdaf( * Operators that require maintaining state in between input rows should instead be implemented as * user defined aggregations, which have clean semantics even in a partitioned execution. */ -private[hive] case class HiveGenericUdtf( +private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors { @@ -553,7 +553,7 @@ private[hive] case class HiveGenericUdtf( } } -private[hive] case class HiveUdafFunction( +private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], base: AggregateExpression, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ea325cc93cb85..7978fdacaedba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -391,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * Records the UDFs present when the server starts, so we can delete ones that are created by * tests. */ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** * Resets the test instance by deleting any tables that have been created. @@ -410,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000 similarity index 100% rename from sql/hive/src/test/resources/data/files/testUdf/part-00000 rename to sql/hive/src/test/resources/data/files/testUDF/part-00000 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala similarity index 93% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ce5985888f540..56b0bef1d0571 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -46,7 +46,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUdfSuite extends QueryTest { +class HiveUDFSuite extends QueryTest { import TestHive.{udf, sql} import TestHive.implicits._ @@ -73,7 +73,7 @@ class HiveUdfSuite extends QueryTest { test("hive struct udf") { sql( """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( + |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) @@ -82,15 +82,15 @@ class HiveUdfSuite extends QueryTest { """. stripMargin.format(classOf[PairSerDe].getName)) - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") } test("SPARK-6409 UDAFAverage test") { @@ -169,11 +169,11 @@ class HiveUdfSuite extends QueryTest { StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") TestHive.reset() } @@ -244,7 +244,7 @@ class PairSerDe extends AbstractSerDe { } } -class PairUdf extends GenericUDF { +class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), From ed359de595d5dd67b666660eddf092eaf89041c8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Jun 2015 15:59:20 -0700 Subject: [PATCH 015/274] [SPARK-8579] [SQL] support arbitrary object in UnsafeRow This PR brings arbitrary object support in UnsafeRow (both in grouping key and aggregation buffer). Two object pools will be created to hold those non-primitive objects, and put the index of them into UnsafeRow. In order to compare the grouping key as bytes, the objects in key will be stored in a unique object pool, to make sure same objects will have same index (used as hashCode). For StringType and BinaryType, we still put them as var-length in UnsafeRow when initializing for better performance. But for update, they will be an object inside object pools (there will be some garbages left in the buffer). BTW: Will create a JIRA once issue.apache.org is available. cc JoshRosen rxin Author: Davies Liu Closes #6959 from davies/unsafe_obj and squashes the following commits: 5ce39da [Davies Liu] fix comment 5e797bf [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 5803d64 [Davies Liu] fix conflict 461d304 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 2f41c90 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj b04d69c [Davies Liu] address comments 4859b80 [Davies Liu] fix comments f38011c [Davies Liu] add a test for grouping by decimal d2cf7ab [Davies Liu] add more tests for null checking 71983c5 [Davies Liu] add test for timestamp e8a1649 [Davies Liu] reuse buffer for string 39f09ca [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 035501e [Davies Liu] fix style 236d6de [Davies Liu] support arbitrary object in UnsafeRow --- .../UnsafeFixedWidthAggregationMap.java | 144 ++++++------ .../sql/catalyst/expressions/UnsafeRow.java | 218 +++++++++--------- .../spark/sql/catalyst/util/ObjectPool.java | 78 +++++++ .../sql/catalyst/util/UniqueObjectPool.java | 59 +++++ .../spark/sql/catalyst/InternalRow.scala | 5 +- .../expressions/UnsafeRowConverter.scala | 94 +++----- .../UnsafeFixedWidthAggregationMapSuite.scala | 65 ++++-- .../expressions/UnsafeRowConverterSuite.scala | 190 +++++++++++---- .../sql/catalyst/util/ObjectPoolSuite.scala | 57 +++++ .../sql/execution/GeneratedAggregate.scala | 16 +- 10 files changed, 615 insertions(+), 311 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 83f2a312972fb..1e79f4b2e88e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -19,9 +19,11 @@ import java.util.Iterator; +import scala.Function1; + import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.catalyst.util.UniqueObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -38,26 +40,48 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. */ - private final byte[] emptyAggregationBuffer; + private final byte[] emptyBuffer; - private final StructType aggregationBufferSchema; + /** + * An empty row used by `initProjection` + */ + private static final InternalRow emptyRow = new GenericInternalRow(); - private final StructType groupingKeySchema; + /** + * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. + */ + private final boolean reuseEmptyBuffer; /** - * Encodes grouping keys as UnsafeRows. + * The projection used to initialize the emptyBuffer */ - private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + private final Function1 initProjection; + + /** + * Encodes grouping keys or buffers as UnsafeRows. + */ + private final UnsafeRowConverter keyConverter; + private final UnsafeRowConverter bufferConverter; /** * A hashmap which maps from opaque bytearray keys to bytearray values. */ private final BytesToBytesMap map; + /** + * An object pool for objects that are used in grouping keys. + */ + private final UniqueObjectPool keyPool; + + /** + * An object pool for objects that are used in aggregation buffers. + */ + private final ObjectPool bufferPool; + /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UnsafeRow currentBuffer = new UnsafeRow(); /** * Scratch space that is used when encoding grouping keys into UnsafeRow format. @@ -69,68 +93,39 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; - /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. - */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - - /** - * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given - * schema, false otherwise. - */ - public static boolean supportsAggregationBufferSchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) - * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. - * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param initProjection the default value for new keys (a "zero" of the agg. function) + * @param keyConverter the converter of the grouping key, used for row conversion. + * @param bufferConverter the converter of the aggregation buffer, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - InternalRow emptyAggregationBuffer, - StructType aggregationBufferSchema, - StructType groupingKeySchema, + Function1 initProjection, + UnsafeRowConverter keyConverter, + UnsafeRowConverter bufferConverter, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.emptyAggregationBuffer = - convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); - this.aggregationBufferSchema = aggregationBufferSchema; - this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); - this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.initProjection = initProjection; + this.keyConverter = keyConverter; + this.bufferConverter = bufferConverter; this.enablePerfMetrics = enablePerfMetrics; - } - /** - * Convert a Java object row into an UnsafeRow, allocating it into a new byte array. - */ - private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) { - final UnsafeRowConverter converter = new UnsafeRowConverter(schema); - final byte[] unsafeRow = new byte[converter.getSizeRequirement(javaRow)]; - final int writtenLength = - converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET); - assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; - return unsafeRow; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.keyPool = new UniqueObjectPool(100); + this.bufferPool = new ObjectPool(initialCapacity); + + InternalRow initRow = initProjection.apply(emptyRow); + this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int writtenLength = bufferConverter.writeRow( + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; + // re-use the empty buffer only when there is no object saved in pool. + reuseEmptyBuffer = bufferPool.size() == 0; } /** @@ -138,15 +133,16 @@ private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) * return the same object. */ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { - final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { groupingKeyConversionScratchSpace = new byte[groupingKeySize]; } - final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + final int actualGroupingKeySize = keyConverter.writeRow( groupingKey, groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET); + PlatformDependent.BYTE_ARRAY_OFFSET, + keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; // Probe our map using the serialized key @@ -157,25 +153,31 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: + if (!reuseEmptyBuffer) { + // There is some objects referenced by emptyBuffer, so generate a new one + InternalRow initRow = initProjection.apply(emptyRow); + bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, + bufferPool); + } loc.putNewKey( groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, - emptyAggregationBuffer, + emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - emptyAggregationBuffer.length + emptyBuffer.length ); } // Reset the pointer to point to the value that we just stored or looked up: final MemoryLocation address = loc.getValueAddress(); - currentAggregationBuffer.pointTo( + currentBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); - return currentAggregationBuffer; + return currentBuffer; } /** @@ -211,14 +213,14 @@ public MapEntry next() { entry.key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - groupingKeySchema.length(), - groupingKeySchema + keyConverter.numFields(), + keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); return entry; } @@ -246,6 +248,8 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + System.out.println("Number of unique objects in keys: " + keyPool.size()); + System.out.println("Number of objects in buffers: " + bufferPool.size()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 11d51d90f1802..f077064a02ec0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,20 +17,12 @@ package org.apache.spark.sql.catalyst.expressions; -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -44,7 +36,20 @@ * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the * base address of the row) that points to the beginning of the variable-length field, and length - * (they are combined into a long). + * (they are combined into a long). For other objects, they are stored in a pool, the indexes of + * them are hold in the the word. + * + * In order to support fast hashing and equality checks for UnsafeRows that contain objects + * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make + * sure all the key have the same index for same object, then we can hash/compare the objects by + * hash/compare the index. + * + * For non-primitive types, the word of a field could be: + * UNION { + * [1] [offset: 31bits] [length: 31bits] // StringType + * [0] [offset: 31bits] [length: 31bits] // BinaryType + * - [index: 63bits] // StringType, Binary, index to object in pool + * } * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ @@ -53,8 +58,12 @@ public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; + /** A pool to hold non-primitive objects */ + private ObjectPool pool; + Object getBaseObject() { return baseObject; } long getBaseOffset() { return baseOffset; } + ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; @@ -63,15 +72,6 @@ public final class UnsafeRow extends MutableRow { /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; - /** - * This optional schema is required if you want to call generic get() and set() methods on - * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() - * methods. This should be removed after the planned InternalRow / Row split; right now, it's only - * needed by the generic get() method, which is only called internally by code that accesses - * UTF8String-typed columns. - */ - @Nullable - private StructType schema; private long getFieldOffset(int ordinal) { return baseOffset + bitSetWidthInBytes + ordinal * 8L; @@ -81,42 +81,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; } - /** - * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) - */ - public static final Set settableFieldTypes; - - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). - */ - public static final Set readableFieldTypes; - - // TODO: support DecimalType - static { - settableFieldTypes = Collections.unmodifiableSet( - new HashSet( - Arrays.asList(new DataType[] { - NullType, - BooleanType, - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType, - DateType, - TimestampType - }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet( - Arrays.asList(new DataType[]{ - StringType, - BinaryType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); - } + public static final long OFFSET_BITS = 31L; /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, @@ -130,22 +95,15 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row - * @param schema an optional schema; this is necessary if you want to call generic get() or set() - * methods on this row, but is optional if the caller will only use type-specific - * getTYPE() and setTYPE() methods. + * @param pool the object pool to hold arbitrary objects */ - public void pointTo( - Object baseObject, - long baseOffset, - int numFields, - @Nullable StructType schema) { + public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; - assert schema == null || schema.fields().length == numFields; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; - this.schema = schema; + this.pool = pool; } private void assertIndexIsValid(int index) { @@ -168,9 +126,68 @@ private void setNotNullAt(int i) { BitSetMethods.unset(baseObject, baseOffset, i); } + /** + * Updates the column `i` as Object `value`, which cannot be primitive types. + */ @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); + public void update(int i, Object value) { + if (value == null) { + if (!isNullAt(i)) { + // remove the old value from pool + long idx = getLong(i); + if (idx <= 0) { + // this is the index of old value in pool, remove it + pool.replace((int)-idx, null); + } else { + // there will be some garbage left (UTF8String or byte[]) + } + setNullAt(i); + } + return; + } + + if (isNullAt(i)) { + // there is not an old value, put the new value into pool + int idx = pool.put(value); + setLong(i, (long)-idx); + } else { + // there is an old value, check the type, then replace it or update it + long v = getLong(i); + if (v <= 0) { + // it's the index in the pool, replace old value with new one + int idx = (int)-v; + pool.replace(idx, value); + } else { + // old value is UTF8String or byte[], try to reuse the space + boolean isString; + byte[] newBytes; + if (value instanceof UTF8String) { + newBytes = ((UTF8String) value).getBytes(); + isString = true; + } else { + newBytes = (byte[]) value; + isString = false; + } + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int oldLength = (int) (v & Integer.MAX_VALUE); + if (newBytes.length <= oldLength) { + // the new value can fit in the old buffer, re-use it + PlatformDependent.copyMemory( + newBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + offset, + newBytes.length); + long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L; + setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length); + } else { + // Cannot fit in the buffer + int idx = pool.put(value); + setLong(i, (long) -idx); + } + } + } + setNotNullAt(i); } @Override @@ -227,28 +244,38 @@ public int size() { return numFields; } - @Override - public StructType schema() { - return schema; - } - + /** + * Returns the object for column `i`, which should not be primitive type. + */ @Override public Object get(int i) { assertIndexIsValid(i); - assert (schema != null) : "Schema must be defined when calling generic get() method"; - final DataType dataType = schema.fields()[i].dataType(); - // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic - // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to - // separate the internal and external row interfaces, then internal code can fetch strings via - // a new getUTF8String() method and we'll be able to remove this method. if (isNullAt(i)) { return null; - } else if (dataType == StringType) { - return getUTF8String(i); - } else if (dataType == BinaryType) { - return getBinary(i); + } + long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + if (v <= 0) { + // It's an index to object in the pool. + int idx = (int)-v; + return pool.get(idx); } else { - throw new UnsupportedOperationException(); + // The column could be StingType or BinaryType + boolean isString = (v >> (OFFSET_BITS * 2)) > 0; + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int size = (int) (v & Integer.MAX_VALUE); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size + ); + if (isString) { + return UTF8String.fromBytes(bytes); + } else { + return bytes; + } } } @@ -308,31 +335,6 @@ public double getDouble(int i) { } } - public UTF8String getUTF8String(int i) { - return UTF8String.fromBytes(getBinary(i)); - } - - public byte[] getBinary(int i) { - assertIndexIsValid(i); - final long offsetAndSize = getLong(i); - final int offset = (int)(offsetAndSize >> 32); - final int size = (int)(offsetAndSize & ((1L << 32) - 1)); - final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size - ); - return bytes; - } - - @Override - public String getString(int i) { - return getUTF8String(i).toString(); - } - @Override public InternalRow copy() { throw new UnsupportedOperationException(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java new file mode 100644 index 0000000000000..97f89a7d0b758 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +/** + * A object pool stores a collection of objects in array, then they can be referenced by the + * pool plus an index. + */ +public class ObjectPool { + + /** + * An array to hold objects, which will grow as needed. + */ + private Object[] objects; + + /** + * How many objects in the pool. + */ + private int numObj; + + public ObjectPool(int capacity) { + objects = new Object[capacity]; + numObj = 0; + } + + /** + * Returns how many objects in the pool. + */ + public int size() { + return numObj; + } + + /** + * Returns the object at position `idx` in the array. + */ + public Object get(int idx) { + assert (idx < numObj); + return objects[idx]; + } + + /** + * Puts an object `obj` at the end of array, returns the index of it. + *

+ * The array will grow as needed. + */ + public int put(Object obj) { + if (numObj >= objects.length) { + Object[] tmp = new Object[objects.length * 2]; + System.arraycopy(objects, 0, tmp, 0, objects.length); + objects = tmp; + } + objects[numObj++] = obj; + return numObj - 1; + } + + /** + * Replaces the object at `idx` with new one `obj`. + */ + public void replace(int idx, Object obj) { + assert (idx < numObj); + objects[idx] = obj; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java new file mode 100644 index 0000000000000..d512392dcaacc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +import java.util.HashMap; + +/** + * An unique object pool stores a collection of unique objects in it. + */ +public class UniqueObjectPool extends ObjectPool { + + /** + * A hash map from objects to their indexes in the array. + */ + private HashMap objIndex; + + public UniqueObjectPool(int capacity) { + super(capacity); + objIndex = new HashMap(); + } + + /** + * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will + * return the index of the existing one. + */ + @Override + public int put(Object obj) { + if (objIndex.containsKey(obj)) { + return objIndex.get(obj); + } else { + int idx = super.put(obj); + objIndex.put(obj, idx); + return idx; + } + } + + /** + * The objects can not be replaced. + */ + @Override + public void replace(int idx, Object obj) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 61a29c89d8df3..57de0f26a9720 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -28,7 +28,10 @@ import org.apache.spark.unsafe.types.UTF8String abstract class InternalRow extends Row { // This is only use for test - override def getString(i: Int): String = getAs[UTF8String](i).toString + override def getString(i: Int): String = { + val str = getAs[UTF8String](i) + if (str != null) str.toString else null + } // These expensive API should not be used internally. final override def getDecimal(i: Int): java.math.BigDecimal = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index b61d490429e4f..b11fc245c4af9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -33,6 +34,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { this(schema.fields.map(_.dataType)) } + def numFields: Int = fieldTypes.length + /** Re-used pointer to the unsafe row being written */ private[this] val unsafeRow = new UnsafeRow() @@ -68,8 +71,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param baseOffset the base offset of the destination address * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) if (writers.length > 0) { // zero-out the bitset @@ -84,16 +87,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { } var fieldNumber = 0 - var appendCursor: Int = fixedLengthSize + var cursor: Int = fixedLengthSize while (fieldNumber < writers.length) { if (row.isNullAt(fieldNumber)) { unsafeRow.setNullAt(fieldNumber) } else { - appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) } fieldNumber += 1 } - appendCursor + cursor } } @@ -108,11 +111,11 @@ private abstract class UnsafeColumnWriter { * @param source the row being converted * @param target a pointer to the converted unsafe row * @param column the column to write - * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * @param cursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written */ - def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int + def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int /** * Return the number of bytes that are needed to write this variable-length value. @@ -134,8 +137,7 @@ private object UnsafeColumnWriter { case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case t => - throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + case t => ObjectUnsafeColumnWriter } } } @@ -152,6 +154,7 @@ private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter +private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: @@ -159,88 +162,56 @@ private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { } private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 } @@ -255,12 +226,10 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { - val offset = target.getBaseOffset + appendCursor + protected[this] def isString: Boolean + + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor val bytes = getBytes(source, column) val numBytes = bytes.length if ((numBytes & 0x07) > 0) { @@ -274,19 +243,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { offset, numBytes ) - target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong) + val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 + target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } } private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = true def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[UTF8String](column).getBytes } } private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = false def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[Array[Byte]](column) } } + +private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(sourceRow: InternalRow, column: Int): Int = 0 + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val obj = source.get(column) + val idx = target.getPool.put(obj) + target.setLong(column, - idx) + 0 + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 3095ccb77761b..6fafc2f86684c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,8 +23,9 @@ import scala.util.Random import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +34,10 @@ class UnsafeFixedWidthAggregationMapSuite with Matchers with BeforeAndAfterEach { - import UnsafeFixedWidthAggregationMap._ - private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyProjection: Projection = + GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: TaskMemoryManager = null @@ -52,21 +53,11 @@ class UnsafeFixedWidthAggregationMapSuite } } - test("supported schemas") { - assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - - assert( - !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - } - test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -77,9 +68,9 @@ class UnsafeFixedWidthAggregationMapSuite test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -103,9 +94,9 @@ class UnsafeFixedWidthAggregationMapSuite test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 128, // initial capacity false // disable perf metrics @@ -120,6 +111,36 @@ class UnsafeFixedWidthAggregationMapSuite }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) + + map.free() + } + + test("with decimal in the key and values") { + val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) + val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) + val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), + Seq(AttributeReference("price", DecimalType.Unlimited)())) + val map = new UnsafeFixedWidthAggregationMap( + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), + memoryManager, + 1, // initial capacity + false // disable perf metrics + ) + + (0 until 100).foreach { i => + val groupKey = InternalRow(Decimal(i % 10)) + val row = map.getAggregationBuffer(groupKey) + row.update(0, Decimal(i)) + } + val seenKeys: Set[Int] = map.iterator().asScala.map { entry => + entry.key.getAs[Decimal](0).toInt + }.toSet + seenKeys.size should be (10) + seenKeys should be ((0 until 10).toSet) + + map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c0675f4f4dff6..94c2f3242b122 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -23,10 +23,11 @@ import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.UTF8String class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { @@ -40,16 +41,21 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.setInt(2, 2) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (3 * 8)) + assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getLong(1) should be (1) - unsafeRow.getInt(2) should be (2) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getLong(1) === 1) + assert(unsafeRow.getInt(2) === 2) + + unsafeRow.setLong(1, 3) + assert(unsafeRow.getLong(1) === 3) + unsafeRow.setInt(2, 4) + assert(unsafeRow.getInt(2) === 4) } test("basic conversion with primitive, string and binary types") { @@ -58,22 +64,67 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) - row.setString(1, "Hello") + row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 3) + + assert(sizeRequired === 8 + (8 * 3) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") - unsafeRow.getBinary(2) should be ("World".getBytes) + val pool = new ObjectPool(10) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") + assert(unsafeRow.get(2) === "World".getBytes) + + unsafeRow.update(1, UTF8String.fromString("World")) + assert(unsafeRow.getString(1) === "World") + assert(pool.size === 0) + unsafeRow.update(1, UTF8String.fromString("Hello World")) + assert(unsafeRow.getString(1) === "Hello World") + assert(pool.size === 1) + + unsafeRow.update(2, "World".getBytes) + assert(unsafeRow.get(2) === "World".getBytes) + assert(pool.size === 1) + unsafeRow.update(2, "Hello World".getBytes) + assert(unsafeRow.get(2) === "Hello World".getBytes) + assert(pool.size === 2) + } + + test("basic conversion with primitive, decimal and array") { + val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.update(1, Decimal(1)) + row.update(2, Array(2)) + + val pool = new ObjectPool(10) + val sizeRequired: Int = converter.getSizeRequirement(row) + assert(sizeRequired === 8 + (8 * 3)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + assert(numBytesWritten === sizeRequired) + assert(pool.size === 2) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.get(1) === Decimal(1)) + assert(unsafeRow.get(2) === Array(2)) + + unsafeRow.update(1, Decimal(2)) + assert(unsafeRow.get(1) === Decimal(2)) + unsafeRow.update(2, Array(3, 4)) + assert(unsafeRow.get(2) === Array(3, 4)) + assert(pool.size === 2) } test("basic conversion with primitive, string, date and timestamp types") { @@ -87,21 +138,27 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 4) + + assert(sizeRequired === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow - DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01")) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) // Timestamp is represented as Long in unsafeRow DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be (Timestamp.valueOf("2015-05-08 08:10:25")) + + unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) + unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) + DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -113,7 +170,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { IntegerType, LongType, FloatType, - DoubleType) + DoubleType, + StringType, + BinaryType, + DecimalType.Unlimited, + ArrayType(IntegerType) + ) val converter = new UnsafeRowConverter(fieldTypes) val rowWithAllNullColumns: InternalRow = { @@ -127,8 +189,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( @@ -136,13 +198,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } - createdFromNull.getBoolean(1) should be (false) - createdFromNull.getByte(2) should be (0) - createdFromNull.getShort(3) should be (0) - createdFromNull.getInt(4) should be (0) - createdFromNull.getLong(5) should be (0) + assert(createdFromNull.getBoolean(1) === false) + assert(createdFromNull.getByte(2) === 0) + assert(createdFromNull.getShort(3) === 0) + assert(createdFromNull.getInt(4) === 0) + assert(createdFromNull.getLong(5) === 0) assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getString(8) === null) + assert(createdFromNull.get(9) === null) + assert(createdFromNull.get(10) === null) + assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by @@ -157,28 +223,68 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setLong(5, 500) r.setFloat(6, 600) r.setDouble(7, 700) + r.update(8, UTF8String.fromString("hello")) + r.update(9, "world".getBytes) + r.update(10, Decimal(10)) + r.update(11, Array(11)) r } - val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val pool = new ObjectPool(1) + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) - setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) - setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) - setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) - setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) - setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) - setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) - setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) - setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- 0 to fieldTypes.length - 1) { + if (i >= 8) { + setToNullAfterCreation.update(i, null) + } setToNullAfterCreation.setNullAt(i) } - assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + // There are some garbage left in the var-length area + assert(Arrays.equals(createdFromNullBuffer, + java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8))) + + setToNullAfterCreation.setNullAt(0) + setToNullAfterCreation.setBoolean(1, false) + setToNullAfterCreation.setByte(2, 20) + setToNullAfterCreation.setShort(3, 30) + setToNullAfterCreation.setInt(4, 400) + setToNullAfterCreation.setLong(5, 500) + setToNullAfterCreation.setFloat(6, 600) + setToNullAfterCreation.setDouble(7, 700) + setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + setToNullAfterCreation.update(9, "world".getBytes) + setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.update(11, Array(11)) + + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala new file mode 100644 index 0000000000000..94764df4b9cdb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class ObjectPoolSuite extends SparkFunSuite with Matchers { + + test("pool") { + val pool = new ObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(false) === 2) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.get(2) === false) + assert(pool.size() === 3) + + pool.replace(1, "world") + assert(pool.get(1) === "world") + assert(pool.size() === 3) + } + + test("unique pool") { + val pool = new UniqueObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.size() === 2) + + intercept[UnsupportedOperationException] { + pool.replace(1, "world") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index ba2c8f53d702d..44930f82b53a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -238,11 +238,6 @@ case class GeneratedAggregate( StructType(fields) } - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) - } - child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -283,12 +278,12 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled && schemaSupportsUnsafe) { + } else if (unsafeEnabled) { log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, + newAggregationBuffer, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggregationBufferSchema), TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity false // disable tracking of performance metrics @@ -323,9 +318,6 @@ case class GeneratedAggregate( } } } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } val buffers = new java.util.HashMap[InternalRow, MutableRow]() var currentRow: InternalRow = null From 4e880cf5967c0933e1d098a1d1f7db34b23ca8f8 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Mon, 29 Jun 2015 16:09:29 -0700 Subject: [PATCH 016/274] [SPARK-8661][ML] for LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments, to make copy-pasting R code more simple for mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments, to make copy-pasting R code more simple Author: Rosstin Closes #7098 from Rosstin/SPARK-8661 and squashes the following commits: 5a05dee [Rosstin] SPARK-8661 for LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments to make it easier to copy-paste the R code. bb9a4b1 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8660 242aedd [Rosstin] SPARK-8660, changed comment style from JavaDoc style to normal multiline comment in order to make copypaste into R easier, in file classification/LogisticRegressionSuite.scala 2cd2985 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 21ac1e5 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- .../ml/regression/LinearRegressionSuite.scala | 192 +++++++++--------- 1 file changed, 96 insertions(+), 96 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index ad1e9da692ee2..5f39d44f37352 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -28,26 +28,26 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var datasetWithoutIntercept: DataFrame = _ - /** - * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML - * is the same as the one trained by R's glmnet package. The following instruction - * describes how to reproduce the data in R. - * - * import org.apache.spark.mllib.util.LinearDataGenerator - * val data = - * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), - * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) - * .saveAsTextFile("path") + /* + In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + is the same as the one trained by R's glmnet package. The following instruction + describes how to reproduce the data in R. + + import org.apache.spark.mllib.util.LinearDataGenerator + val data = + sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), + Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) + .saveAsTextFile("path") */ override def beforeAll(): Unit = { super.beforeAll() dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) - /** - * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating - * training model without intercept + /* + datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating + training model without intercept */ datasetWithoutIntercept = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -59,20 +59,20 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = new LinearRegression val model = trainer.fit(dataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * library("glmnet") - * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) - * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) - * label <- as.numeric(data$V1) - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.300528 - * as.numeric.data.V2. 4.701024 - * as.numeric.data.V3. 7.198257 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.300528 + as.numeric.data.V2. 4.701024 + as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 val weightsR = Array(4.700706, 7.199082) @@ -94,29 +94,29 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model = trainer.fit(dataset) val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, - * intercept = FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 6.995908 - * as.numeric.data.V3. 5.275131 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.995908 + as.numeric.data.V3. 5.275131 */ val weightsR = Array(6.995908, 5.275131) assert(model.intercept ~== 0 relTol 1E-3) assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights(1) ~== weightsR(1) relTol 1E-3) - /** - * Then again with the data with no intercept: - * > weightsWithoutIntercept - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data3.V2. 4.70011 - * as.numeric.data3.V3. 7.19943 + /* + Then again with the data with no intercept: + > weightsWithoutIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 7.19943 */ val weightsWithoutInterceptR = Array(4.70011, 7.19943) @@ -129,14 +129,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.24300 - * as.numeric.data.V2. 4.024821 - * as.numeric.data.V3. 6.679841 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.24300 + as.numeric.data.V2. 4.024821 + as.numeric.data.V3. 6.679841 */ val interceptR = 6.24300 val weightsR = Array(4.024821, 6.679841) @@ -158,15 +158,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - * intercept=FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 6.299752 - * as.numeric.data.V3. 4.772913 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.299752 + as.numeric.data.V3. 4.772913 */ val interceptR = 0.0 val weightsR = Array(6.299752, 4.772913) @@ -187,14 +187,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.328062 - * as.numeric.data.V2. 3.222034 - * as.numeric.data.V3. 4.926260 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.328062 + as.numeric.data.V2. 3.222034 + as.numeric.data.V3. 4.926260 */ val interceptR = 5.269376 val weightsR = Array(3.736216, 5.712356) @@ -216,15 +216,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, - * intercept = FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 5.522875 - * as.numeric.data.V3. 4.214502 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.522875 + as.numeric.data.V3. 4.214502 */ val interceptR = 0.0 val weightsR = Array(5.522875, 4.214502) @@ -245,14 +245,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.324108 - * as.numeric.data.V2. 3.168435 - * as.numeric.data.V3. 5.200403 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.324108 + as.numeric.data.V2. 3.168435 + as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 val weightsR = Array(3.670489, 6.001122) @@ -274,15 +274,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - * intercept=FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.dataM.V2. 5.673348 - * as.numeric.dataM.V3. 4.322251 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.dataM.V2. 5.673348 + as.numeric.dataM.V3. 4.322251 */ val interceptR = 0.0 val weightsR = Array(5.673348, 4.322251) From 4b497a724a87ef24702c2df9ec6863ee57a87c1c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 29 Jun 2015 16:26:05 -0700 Subject: [PATCH 017/274] [SPARK-8710] [SQL] Change ScalaReflection.mirror from a val to a def. jira: https://issues.apache.org/jira/browse/SPARK-8710 Author: Yin Huai Closes #7094 from yhuai/SPARK-8710 and squashes the following commits: c854baa [Yin Huai] Change ScalaReflection.mirror from a val to a def. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 90698cd572de4..21b1de1ab9cb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -28,7 +28,11 @@ import org.apache.spark.sql.types._ */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) + // Since we are creating a runtime mirror usign the class loader of current thread, + // we need to use def at here. So, every time we call mirror, it is using the + // class loader of the current thread. + override def mirror: universe.Mirror = + universe.runtimeMirror(Thread.currentThread().getContextClassLoader) } /** @@ -39,7 +43,7 @@ trait ScalaReflection { val universe: scala.reflect.api.Universe /** The mirror used to access types in the universe */ - val mirror: universe.Mirror + def mirror: universe.Mirror import universe._ From 881662e9c93893430756320f51cef0fc6643f681 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 29 Jun 2015 16:34:50 -0700 Subject: [PATCH 018/274] [SPARK-8589] [SQL] cleanup DateTimeUtils move date time related operations into `DateTimeUtils` and rename some methods to make it more clear. Author: Wenchen Fan Closes #6980 from cloud-fan/datetime and squashes the following commits: 9373a9d [Wenchen Fan] cleanup DateTimeUtil --- .../spark/sql/catalyst/expressions/Cast.scala | 43 ++---------- .../sql/catalyst/util/DateTimeUtils.scala | 70 +++++++++++++------ .../spark/sql/hive/hiveWriterContainers.scala | 2 +- 3 files changed, 58 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8d66968a2fc35..d69d490ad666a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} -import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -122,9 +121,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes) - case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.toString(d))) + case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, - t => UTF8String.fromString(timestampToString(DateTimeUtils.toJavaTimestamp(t)))) + t => UTF8String.fromString(DateTimeUtils.timestampToString(t))) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -183,7 +182,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => DateTimeUtils.toMillisSinceEpoch(d) * 10000) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 10000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -216,18 +215,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w ts / 10000000.0 } - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToString(ts: Timestamp): String = { - val timestampString = ts.toString - val formatted = Cast.threadLocalTimestampFormat.get.format(ts) - - if (timestampString.length > 19 && timestampString.substring(19) != ".0") { - formatted + timestampString.substring(19) - } else { - formatted - } - } - // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => @@ -449,11 +436,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (DateType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.toString($c))""") - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") case (TimestampType, StringType) => - super.genCode(ctx, ev) + defineCodeGen(ctx, ev, c => + s"""${ctx.stringType}.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") case (_, StringType) => defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") @@ -477,19 +464,3 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } } - -object Cast { - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - } - } - - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index ff79884a44d00..640e67e2ecd76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat +import java.text.{DateFormat, SimpleDateFormat} import java.util.{Calendar, TimeZone} -import org.apache.spark.sql.catalyst.expressions.Cast - /** * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of @@ -41,35 +39,53 @@ object DateTimeUtils { // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. - private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { + private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { override protected def initialValue: TimeZone = { Calendar.getInstance.getTimeZone } } - private def javaDateToDays(d: Date): Int = { - millisToDays(d.getTime) + // `SimpleDateFormat` is not thread-safe. + private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + } + } + + // `SimpleDateFormat` is not thread-safe. + private val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd") + } } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { - ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + ((millisLocal + threadLocalLocalTimeZone.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } - def toMillisSinceEpoch(days: Int): Long = { + // reverse of millisToDays + def daysToMillis(days: Int): Long = { val millisUtc = days.toLong * MILLIS_PER_DAY - millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) + millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc) } - def fromJavaDate(date: Date): Int = { - javaDateToDays(date) - } + def dateToString(days: Int): String = + threadLocalDateFormat.get.format(toJavaDate(days)) - def toJavaDate(daysSinceEpoch: Int): Date = { - new Date(toMillisSinceEpoch(daysSinceEpoch)) - } + // Converts Timestamp to string according to Hive TimestampWritable convention. + def timestampToString(num100ns: Long): String = { + val ts = toJavaTimestamp(num100ns) + val timestampString = ts.toString + val formatted = threadLocalTimestampFormat.get.format(ts) - def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) + if (timestampString.length > 19 && timestampString.substring(19) != ".0") { + formatted + timestampString.substring(19) + } else { + formatted + } + } def stringToTime(s: String): java.util.Date = { if (!s.contains('T')) { @@ -100,7 +116,21 @@ object DateTimeUtils { } /** - * Return a java.sql.Timestamp from number of 100ns since epoch + * Returns the number of days since epoch from from java.sql.Date. + */ + def fromJavaDate(date: Date): Int = { + millisToDays(date.getTime) + } + + /** + * Returns a java.sql.Date from number of days since epoch. + */ + def toJavaDate(daysSinceEpoch: Int): Date = { + new Date(daysToMillis(daysSinceEpoch)) + } + + /** + * Returns a java.sql.Timestamp from number of 100ns since epoch. */ def toJavaTimestamp(num100ns: Long): Timestamp = { // setNanos() will overwrite the millisecond part, so the milliseconds should be @@ -118,7 +148,7 @@ object DateTimeUtils { } /** - * Return the number of 100ns since epoch from java.sql.Timestamp. + * Returns the number of 100ns since epoch from java.sql.Timestamp. */ def fromJavaTimestamp(t: Timestamp): Long = { if (t != null) { @@ -129,7 +159,7 @@ object DateTimeUtils { } /** - * Return the number of 100ns (hundred of nanoseconds) since epoch from Julian day + * Returns the number of 100ns (hundred of nanoseconds) since epoch from Julian day * and nanoseconds in a day */ def fromJulianDay(day: Int, nanoseconds: Long): Long = { @@ -139,7 +169,7 @@ object DateTimeUtils { } /** - * Return Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) + * Returns Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) */ def toJulianDay(num100ns: Long): (Int, Long) = { val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ab75b12e2a2e7..ecc78a5f8d321 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -201,7 +201,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { - case DateType => DateTimeUtils.toString(raw.toInt) + case DateType => DateTimeUtils.dateToString(raw.toInt) case _: DecimalType => BigDecimal(raw).toString() case _ => raw } From cec98525fd2b731cb78935bf7bc6c7963411744e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 29 Jun 2015 17:19:05 -0700 Subject: [PATCH 019/274] [SPARK-8634] [STREAMING] [TESTS] Fix flaky test StreamingListenerSuite "receiver info reporting" As per the unit test log in https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/35754/ ``` 15/06/24 23:09:10.210 Thread-3495 INFO ReceiverTracker: Starting 1 receivers 15/06/24 23:09:10.270 Thread-3495 INFO SparkContext: Starting job: apply at Transformer.scala:22 ... 15/06/24 23:09:14.259 ForkJoinPool-4-worker-29 INFO StreamingListenerSuiteReceiver: Started receiver and sleeping 15/06/24 23:09:14.270 ForkJoinPool-4-worker-29 INFO StreamingListenerSuiteReceiver: Reporting error and sleeping ``` it needs at least 4 seconds to receive all receiver events in this slow machine, but `timeout` for `eventually` is only 2 seconds. This PR increases `timeout` to make this test stable. Author: zsxwing Closes #7017 from zsxwing/SPARK-8634 and squashes the following commits: 719cae4 [zsxwing] Fix flaky test StreamingListenerSuite "receiver info reporting" --- .../org/apache/spark/streaming/StreamingListenerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 1dc8960d60528..7bc7727a9fbe4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -116,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 From fbf75738feddebb352d5cedf503b573105d4b7a7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 29 Jun 2015 17:20:05 -0700 Subject: [PATCH 020/274] [SPARK-7287] [SPARK-8567] [TEST] Add sc.stop to applications in SparkSubmitSuite Hopefully, this suite will not be flaky anymore. Author: Yin Huai Closes #7027 from yhuai/SPARK-8567 and squashes the following commits: c0167e2 [Yin Huai] Add sc.stop(). --- .../spark/deploy/SparkSubmitSuite.scala | 2 ++ .../regression-test-SPARK-8489/Main.scala | 1 + .../regression-test-SPARK-8489/test.jar | Bin 6811 -> 6828 bytes 3 files changed, 3 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 357ed90be3f5c..2e05dec99b6bf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -548,6 +548,7 @@ object JarCreationTest extends Logging { if (result.nonEmpty) { throw new Exception("Could not load user class from jar:\n" + result(0)) } + sc.stop() } } @@ -573,6 +574,7 @@ object SimpleApplicationTest { s"Master had $config=$masterValue but executor had $config=$executorValue") } } + sc.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index e1715177e3f1b..0e428ba1d7456 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -38,6 +38,7 @@ object Main { val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() println("Regression test for SPARK-8489 success!") + sc.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar index 4f59fba9eab558131b2587e51b7c2e2d54348bd1..5944aa6076a5fe7a8188c947fd6847c046614101 100644 GIT binary patch delta 1819 zcmY+Fdpy&N8^^!q&d6nHbB(cs!yGh4h#1w-VbXpXd4G^E~h8`RDoEK~QJrZcO7ZR4bX zovjUayHIFvIB4u~enzUU90>C?Xn`=q87-R_p_>v%U9o!iPEZU_oi;@Ykr^AW3Krfy zmcQC-)r+LqTpB$vO1<1}<6pm(VB^$y5NwOaqD4bQ%lmWmu_TxnX~D2i%w^e2>i zBOIKNgW?}+>8K5g4A!-qxZ$Kfc4x3WT+su)P9o8Af=%SWJ_dAhA>7`Spxwjw*L=tp z9;xt`sDXKE+gBR&q{>abkMI;lqA`#0&u&#K?nG~~2SL~xC$!ZoPG6>oxnO9b)VQi& zmifq4VngbH11qpDt`&zCx4@IG=F}QAi@I$SGpebd;+yvDk>Gqh^^Zr+q;E)Sv3G7< zko3JoJ{R~{d!D?tt*=?5 zjdP+JG;wKn??*dme5_x3F0Y#1zXD}L1jCYS(hO=-59%g zE|DI0pw6uQ+#7g=PC($_RmLU#d7X8H6M-h`If3=9>EzsO$7|4z{f~_9nBA9YHb{J_ zH;AnB{gkgED;+oHa?^f-XYFd9ih)l75ME~Vb-nkE0*gTj1gZr zmI6P`fni3aj^2urMdMXv_OVeSclGogc*Sfixun8dB%B@e;ocEqVN^kr(Y(%Hqgnr0 zO77z(l|MsGo-D2lwGQ^jFA#3aoyo0EBDQu)UgFF?DjR*9h3jz>wN@(iL8Gzm6Oo7%tZq;#+T# zrP|J!?{#grxk!)JZ|XD1&30Q;o4Hk|coLU}8p@aHUu;=Q2vEsne)5NF#5xd;WU8r* zE_dC^XVylao(UYVedTTOfTzCrYUpCr(Ul>0JJD_+{ez$OrfbrE(7)1Nf*`%iTeuQu zr{n6W=vbkV1j4V&+uAwQTmkkVIc1x&mcQ0~+&4k=6+{DP?lvD~= z@1okIWs-QA$NIkLJuau}Bv&O@r8V-p=6b-y$Qa$k1=B<3NjP)MY5d()&j{$#xgi-p z{bzWbyOr=KvA!kSVKDOK!h8n=xwz&c%Z_pzb6s2;JRZzDU;hA=c#t8zcQa#|Olh5q zVZB)#=fxz3HO!k#_g>DLI)x-@xWi~ZZQa%#iKTs_4zJ;)p`hQ;(85_<-|9~x)$3TU zX>^MHgbm**D`=hw8Q+wvO#gNBs0(kB)gAo~Z#$v=J2ss*#38DZ%2Lr$lZ@eyB zE(}E8e5hucwbeJ=UGdL@56FhwjB1aqh^3>;5-rH8%=+aRp7a2C9d;?&ljUHw^j7gg>qW}FNHfIq(Cen2?@--XoI+Q}z44Cx!{ z6Io{?A_I~U0zrWP+Y3zzW+(sEP8@2=8Ob}60B}nI0OV;QFjd+POyQf+c)?Ua@1-(E zAhZ-17{9Ci{z~cZU;yAI`yWpG8!r-z`2pte{m#liG%7-9zbNZ|43&mfq}Y~f zle*&8MHGq2V6AJcy6P&^K~<}moi4lk+4Y&{pZ9q`@9%w|_pk4lYr-;-cEy9l6aZi_ z7{Dcu-ji+s9WJ~tQvM(i=dk&phz6cI7O3wrm-=@?O(7CUZB1+m%(^(&nn+=G#?4k@ zcsnp1qW<-275~XT&qCw)0FRSV6c-*xIeyPr5X`poae?~!foeYw7YFQx1tkVy>IQaO zwwtH`SG)w+Kt$-x1p)xk;(){bv6)!Fu{w3wJNc6$HS`0+ss1rBY_HuIC=Q;gQ_L8W zl67<90px-N<#3Q;vwL@JkXjMVR7*X0PZw3X*}E2WA#-O|hxFv`4)M*;LOaWW&#S+? zJ$z{;X<0PX104+6jo&@!OeIdW91M3+|B&{xxDBoTbsVBoWdqhVdzu#}>CU^DvuL-v z)5s~u+Ax9C>06}1F|qrziwDkHyMoWm9fc0D^cm7<9c0&*RzmrpP1}!*f@jeS=m4Zx zsNkG#-5Ki#r864K=2^_O!t)7v)T=1sG5+TsXSKy#DFt+le1lIxL#S>vO4Bc(40S+WxssX?ShFq1rFS^&nB0 z-R9E+8jai{)p2MWccs+jNzioQwVK;G@2?J~J;ad3>x+t(A7#q*Cec26dak*9sFiT_ zm(1D>Go^pgsMEUXFbV$&&K~ob$uX4mo{bIs4Jdju$gF?_8f~~e>#Tgff12j$VbK3| z1WStw;>&6g_Qgv=us;}+qdcJ&tA-<-8>9f8=5L`n&$Z(3jCkFj){oZoT^mflnZn)@ zEML+&Wz%;W7H0Dl^PR@I1ioV~axv+RENl8sbXRvj2@F* z(8ss@z%ujdw5-`2)oDz!qh7Il9nPjve1!TcGFkRus!4+E>bLs^5-7Ls4@)Lr_FMB? zHghXb025;y!^^B21ms3Z+D2=Bxrv0;XrVm+zO zQinBNVMvM;VbLA-Ejx^64A%mf$MmA(8=3rSx*QxPckXWWyyZaa(yiJBQ;gnj@2a6? z*(>`#pYTsCc)y=XccxxlJC(qpp5g3%Xk(ChE=M!3`iJEf0#&zOn(nZ^k{?V8=qhP5SjoAQkmwvS@NTwg`VPB*`Yaf+) zkzvCZzNivX%T_`qlp}HnE*TA*eV@I&;xJLDkT7lDa{Kf!{_jbA>L!78mNTze#Su$y z34WRQx4ERU)O+znKIgn|f7mv#Hi?Xj*^4a^DWy{~@+A_+<_({XmpazumOAWc$_n3? z8_&=Rv@s6#dVj!0-}=%8+ssT?GPm;e0~=9~3&TkQ`8e&dRg} z;LyxvOxn=kTYeQ$5_26g{vRI-jx-u>J6s4ZmT;H5F>4FcBl4oV2y7R+jV77~T zwZ+7!MXE~s7#&afYSS$g9sV+32A=P4-Dcg3t6uazTeA>|T}OhpB&~0}yZx6>p+iE9 zFu_@8g3yv9@8+|5prb}yuMBDC6~BixIWF*uE=8v*y0*5u0yXj`l8rdv`Cymlr@QLr zgnG4wxn1+?-@S|ZnEp;G#c3w$y;WrN@0^{tmrzI8#=SgKaO)7+*^nb_p8zTV{$~>g z8ew-N3kDaV#K=a-FeVVtBdahXkV+~DvHyVlkmH9K|MSfVgc1;c^4kdW%Qt_k&;^^dG1;{1+ From 5d30eae56051c563a8427f330b09ef66db0a0d21 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 29 Jun 2015 17:21:35 -0700 Subject: [PATCH 021/274] [SPARK-8437] [DOCS] Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' Author: Sean Owen Closes #7036 from srowen/SPARK-8437 and squashes the following commits: 0e813ae [Sean Owen] Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3c3bf3746e18..cb7e24c374152 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,6 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory + * rather than `.../path/` or `.../path` * * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @@ -878,9 +880,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory + * rather than `.../path/` or `.../path` + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( From d7f796da45d9a7c76ee4c29a9e0661ef76d8028a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 17:27:02 -0700 Subject: [PATCH 022/274] [SPARK-8410] [SPARK-8475] remove previous ivy resolution when using spark-submit This PR also includes re-ordering the order that repositories are used when resolving packages. User provided repositories will be prioritized. cc andrewor14 Author: Burak Yavuz Closes #7089 from brkyvz/delete-prev-ivy-resolution and squashes the following commits: a21f95a [Burak Yavuz] remove previous ivy resolution when using spark-submit --- .../org/apache/spark/deploy/SparkSubmit.scala | 37 ++++++++++++------- .../spark/deploy/SparkSubmitUtilsSuite.scala | 6 +-- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index abf222757a95b..b1d6ec209d62b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -756,6 +756,20 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) localM2.setRoot(m2Path.toURI.toString) @@ -786,20 +800,6 @@ private[spark] object SparkSubmitUtils { sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -922,6 +922,15 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + md.setDefaultConf(ivyConfName) // Add exclusion rules for Spark and Scala Library diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 12c40f0b7d658..c9b435a9228d3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -77,9 +77,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => - if (i > 3) { - assert(resolver.getName === s"repo-${i - 3}") - assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 4)) + if (i < 3) { + assert(resolver.getName === s"repo-${i + 1}") + assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i)) } } } From 4a9e03fa850af9e4ee56d011671faa04fb601170 Mon Sep 17 00:00:00 2001 From: Michael Sannella x268 Date: Mon, 29 Jun 2015 17:28:28 -0700 Subject: [PATCH 023/274] [SPARK-8019] [SPARKR] Support SparkR spawning worker R processes with a command other then Rscript This is a simple change to add a new environment variable "spark.sparkr.r.command" that specifies the command that SparkR will use when creating an R engine process. If this is not specified, "Rscript" will be used by default. I did not add any documentation, since I couldn't find any place where environment variables (such as "spark.sparkr.use.daemon") are documented. I also did not add a unit test. The only test that would work generally would be one starting SparkR with sparkR.init(sparkEnvir=list(spark.sparkr.r.command="Rscript")), just using the default value. I think that this is a low-risk change. Likely committers: shivaram Author: Michael Sannella x268 Closes #6557 from msannell/altR and squashes the following commits: 7eac142 [Michael Sannella x268] add spark.sparkr.r.command config parameter --- core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 4dfa7325934ff..524676544d6f5 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -391,7 +391,7 @@ private[r] object RRDD { } private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { - val rCommand = "Rscript" + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) From 4c1808be4d3aaa37a5a878892e91ca73ea405ffa Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 29 Jun 2015 18:32:31 -0700 Subject: [PATCH 024/274] Revert "[SPARK-8437] [DOCS] Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles" This reverts commit 5d30eae56051c563a8427f330b09ef66db0a0d21. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cb7e24c374152..b3c3bf3746e18 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,8 +831,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory - * rather than `.../path/` or `.../path` * * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @@ -880,11 +878,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @note Small files are preferred; very large files may cause bad performance. - * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory - * rather than `.../path/` or `.../path` - * * @param minPartitions A suggestion value of the minimal splitting number for input data. + * + * @note Small files are preferred; very large files may cause bad performance. */ @Experimental def binaryFiles( From 620605a4a1123afaab2674e38251f1231dea17ce Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 29 Jun 2015 18:40:30 -0700 Subject: [PATCH 025/274] [SPARK-8456] [ML] Ngram featurizer python Python API for N-gram feature transformer Author: Feynman Liang Closes #6960 from feynmanliang/ngram-featurizer-python and squashes the following commits: f9e37c9 [Feynman Liang] Remove debugging code 4dd81f4 [Feynman Liang] Fix typo and doctest 06c79ac [Feynman Liang] Style guide 26c1175 [Feynman Liang] Add python NGram API --- python/pyspark/ml/feature.py | 71 +++++++++++++++++++++++++++++++++++- python/pyspark/ml/tests.py | 11 ++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ddb33f427ac64..8804dace849b3 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder', +__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel'] @@ -265,6 +265,75 @@ class IDFModel(JavaModel): """ +@inherit_doc +@ignore_unicode_prefix +class NGram(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that converts the input array of strings into an array of n-grams. Null + values in the input array are ignored. + It returns an array of n-grams where each n-gram is represented by a space-separated string of + words. + When the input is empty, an empty array is returned. + When the input array length is less than n (number of elements per n-gram), no n-grams are + returned. + + >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) + >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) + >>> # Change n-gram length + >>> ngram.setParams(n=4).transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Temporarily modify output column. + >>> ngram.transform(df, {ngram.outputCol: "output"}).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e']) + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Must use keyword arguments to specify params. + >>> ngram.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + # a placeholder to make it appear in the generated doc + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + + @keyword_only + def __init__(self, n=2, inputCol=None, outputCol=None): + """ + __init__(self, n=2, inputCol=None, outputCol=None) + """ + super(NGram, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) + self.n = Param(self, "n", "number of elements per n-gram (>=1)") + self._setDefault(n=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, n=2, inputCol=None, outputCol=None): + """ + setParams(self, n=2, inputCol=None, outputCol=None) + Sets params for this NGram. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setN(self, value): + """ + Sets the value of :py:attr:`n`. + """ + self._paramMap[self.n] = value + return self + + def getN(self): + """ + Gets the value of n or its default value. + """ + return self.getOrDefault(self.n) + + @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6adbf166f34a8..c151d21fd661a 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -252,6 +252,17 @@ def test_idf(self): output = idf0m.transform(dataset) self.assertIsNotNone(output.head().idf) + def test_ngram(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + ([["a", "b", "c", "d", "e"]])], ["input"]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + if __name__ == "__main__": unittest.main() From ecacb1e88a135c802e253793e7c863d6ca8d2408 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 18:48:28 -0700 Subject: [PATCH 026/274] [SPARK-8715] ArrayOutOfBoundsException fixed for DataFrameStatSuite.crosstab cc yhuai Author: Burak Yavuz Closes #7100 from brkyvz/ct-flakiness-fix and squashes the following commits: abc299a [Burak Yavuz] change 'to' to until 7e96d7c [Burak Yavuz] ArrayOutOfBoundsException fixed for DataFrameStatSuite.crosstab --- .../test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 64ec1a70c47e6..765094da6bda7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -78,7 +78,7 @@ class DataFrameStatSuite extends SparkFunSuite { val rows = crosstab.collect() rows.foreach { row => val i = row.getString(0).toInt - for (col <- 1 to 9) { + for (col <- 1 until columnNames.length) { val j = columnNames(col).toInt assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) } From 4915e9e3bffb57eac319ef2173b4a6ae4073d25e Mon Sep 17 00:00:00 2001 From: Steven She Date: Mon, 29 Jun 2015 18:50:09 -0700 Subject: [PATCH 027/274] [SPARK-8669] [SQL] Fix crash with BINARY (ENUM) fields with Parquet 1.7 Patch to fix crash with BINARY fields with ENUM original types. Author: Steven She Closes #7048 from stevencanopy/SPARK-8669 and squashes the following commits: 2e72979 [Steven She] [SPARK-8669] [SQL] Fix crash with BINARY (ENUM) fields with Parquet 1.7 --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 2 +- .../org/apache/spark/sql/parquet/ParquetSchemaSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 4fd3e93b70311..2be7c64612cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -177,7 +177,7 @@ private[parquet] class CatalystSchemaConverter( case BINARY => field.getOriginalType match { - case UTF8 => StringType + case UTF8 | ENUM => StringType case null if assumeBinaryIsString => StringType case null => BinaryType case DECIMAL => makeDecimalType() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index d0bfcde7e032b..35d3c33f99a06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -161,6 +161,14 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """.stripMargin, binaryAsString = true) + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} + """.stripMargin) + testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - non-standard", """ From f9b6bf2f83d9dad273aa36d65d0560d35b941cc2 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 29 Jun 2015 18:50:23 -0700 Subject: [PATCH 028/274] [SPARK-7667] [MLLIB] MLlib Python API consistency check MLlib Python API consistency check Author: Yanbo Liang Closes #6856 from yanboliang/spark-7667 and squashes the following commits: 21bae35 [Yanbo Liang] remove duplicate code eb12f95 [Yanbo Liang] fix doc inherit problem 9e7ec3c [Yanbo Liang] address comments e763d32 [Yanbo Liang] MLlib Python API consistency check --- python/pyspark/mllib/feature.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f00bb93b7bf40..b5138773fd61b 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -111,6 +111,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): """ def transform(self, vector): + """ + Applies transformation on a vector or an RDD[Vector]. + + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param vector: Vector or RDD of Vector to be transformed. + """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -191,7 +200,7 @@ def fit(self, dataset): Computes the mean and variance and stores as a model to be used for later scaling. - :param data: The data used to compute the mean and variance + :param dataset: The data used to compute the mean and variance to build the transformation model. :return: a StandardScalarModel """ @@ -346,10 +355,6 @@ def transform(self, x): vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ - if isinstance(x, RDD): - return JavaVectorTransformer.transform(self, x) - - x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) def idf(self): From 7bbbe380c52419cd580d1c99c10131184e4ad440 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 21:32:40 -0700 Subject: [PATCH 029/274] [SPARK-5161] Parallelize Python test execution This commit parallelizes the Python unit test execution, significantly reducing Jenkins build times. Parallelism is now configurable by passing the `-p` or `--parallelism` flags to either `dev/run-tests` or `python/run-tests` (the default parallelism is 4, but I've successfully tested with higher parallelism). To avoid flakiness, I've disabled the Spark Web UI for the Python tests, similar to what we've done for the JVM tests. Author: Josh Rosen Closes #7031 from JoshRosen/parallelize-python-tests and squashes the following commits: feb3763 [Josh Rosen] Re-enable other tests f87ea81 [Josh Rosen] Only log output from failed tests d4ded73 [Josh Rosen] Logging improvements a2717e1 [Josh Rosen] Make parallelism configurable via dev/run-tests 1bacf1b [Josh Rosen] Merge remote-tracking branch 'origin/master' into parallelize-python-tests 110cd9d [Josh Rosen] Fix universal_newlines for Python 3 cd13db8 [Josh Rosen] Also log python_implementation 9e31127 [Josh Rosen] Log Python --version output for each executable. a2b9094 [Josh Rosen] Bump up parallelism. 5552380 [Josh Rosen] Python 3 fix 866b5b9 [Josh Rosen] Fix lazy logging warnings in Prospector checks 87cb988 [Josh Rosen] Skip MLLib tests for PyPy 8309bfe [Josh Rosen] Temporarily disable parallelism to debug a failure 9129027 [Josh Rosen] Disable Spark UI in Python tests 037b686 [Josh Rosen] Temporarily disable JVM tests so we can test Python speedup in Jenkins. af4cef4 [Josh Rosen] Initial attempt at parallelizing Python test execution --- dev/run-tests | 2 +- dev/run-tests.py | 24 +++++++- dev/sparktestsupport/shellutils.py | 1 + python/pyspark/java_gateway.py | 2 + python/run-tests.py | 97 +++++++++++++++++++++++------- 5 files changed, 101 insertions(+), 25 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index a00d9f0c27639..257d1e8d50bb4 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,4 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -exec python -u ./dev/run-tests.py +exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests.py b/dev/run-tests.py index e5c897b94d167..4596e07014733 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -19,6 +19,7 @@ from __future__ import print_function import itertools +from optparse import OptionParser import os import re import sys @@ -360,12 +361,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): run_scala_tests_sbt(test_modules, test_profiles) -def run_python_tests(test_modules): +def run_python_tests(test_modules, parallelism): set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") command = [os.path.join(SPARK_HOME, "python", "run-tests")] if test_modules != [modules.root]: command.append("--modules=%s" % ','.join(m.name for m in test_modules)) + command.append("--parallelism=%i" % parallelism) run_cmd(command) @@ -379,7 +381,25 @@ def run_sparkr_tests(): print("Ignoring SparkR tests as R was not found in PATH") +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + def main(): + opts = parse_opts() # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): print("[error] Cannot determine your home directory as an absolute path;", @@ -461,7 +481,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: - run_python_tests(modules_with_python_tests) + run_python_tests(modules_with_python_tests, opts.parallelism) if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index ad9b0cc89e4ab..12bd0bf3a4fe9 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -15,6 +15,7 @@ # limitations under the License. # +from __future__ import print_function import os import shutil import subprocess diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3cee4ea6e3a35..90cd342a6cf7f 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -51,6 +51,8 @@ def launch_gateway(): on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + if os.environ.get("SPARK_TESTING"): + submit_args = "--conf spark.ui.enabled=false " + submit_args command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/python/run-tests.py b/python/run-tests.py index 7d485b500ee3a..aaa35e936a806 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -18,12 +18,19 @@ # from __future__ import print_function +import logging from optparse import OptionParser import os import re import subprocess import sys +import tempfile +from threading import Thread, Lock import time +if sys.version < '3': + import Queue +else: + import queue as Queue # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -43,34 +50,44 @@ def print_red(text): LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") +FAILURE_REPORTING_LOCK = Lock() +LOGGER = logging.getLogger() def run_individual_python_test(test_name, pyspark_python): env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} - print(" Running test: %s ..." % test_name, end='') + LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() - with open(LOG_FILE, 'a') as log_file: - retcode = subprocess.call( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], - stderr=log_file, stdout=log_file, env=env) + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() duration = time.time() - start_time # Exit on the first failure. if retcode != 0: - with open(LOG_FILE, 'r') as log_file: - for line in log_file: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output.readlines()) + per_test_output.seek(0) + for line in per_test_output: if not re.match('[0-9]+', line): print(line, end='') - print_red("\nHad test failures in %s; see logs." % test_name) - exit(-1) + per_test_output.close() + print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) else: - print("ok (%is)" % duration) + per_test_output.close() + LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] if "python2.6" not in python_execs: - print("WARNING: Not testing against `python2.6` because it could not be found; falling" - " back to `python` instead") + LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") python_execs.insert(0, "python") return python_execs @@ -88,16 +105,31 @@ def parse_opts(): default=",".join(sorted(python_modules.keys())), help="A comma-separated list of Python modules to test (default: %default)" ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + parser.add_option( + "--verbose", action="store_true", + help="Enable additional debug logging" + ) (opts, args) = parser.parse_args() if args: parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") return opts def main(): opts = parse_opts() - print("Running PySpark tests. Output is in python/%s" % LOG_FILE) + if (opts.verbose): + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") + LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') @@ -108,24 +140,45 @@ def main(): else: print("Error: unrecognized module %s" % module_name) sys.exit(-1) - print("Will test against the following Python executables: %s" % python_execs) - print("Will test the following Python modules: %s" % [x.name for x in modules_to_test]) + LOGGER.info("Will test against the following Python executables: %s", python_execs) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) - start_time = time.time() + task_queue = Queue.Queue() for python_exec in python_execs: python_implementation = subprocess.check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() - print("Testing with `%s`: " % python_exec, end='') - subprocess.call([python_exec, "--version"]) - + LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) + LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: - print("Running %s tests ..." % module.name) for test_goal in module.python_test_goals: - run_individual_python_test(test_goal, python_exec) + task_queue.put((python_exec, test_goal)) + + def process_queue(task_queue): + while True: + try: + (python_exec, test_goal) = task_queue.get_nowait() + except Queue.Empty: + break + try: + run_individual_python_test(test_goal, python_exec) + finally: + task_queue.task_done() + + start_time = time.time() + for _ in range(opts.parallelism): + worker = Thread(target=process_queue, args=(task_queue,)) + worker.daemon = True + worker.start() + try: + task_queue.join() + except (KeyboardInterrupt, SystemExit): + print_red("Exiting due to interrupt") + sys.exit(-1) total_duration = time.time() - start_time - print("Tests passed in %i seconds" % total_duration) + LOGGER.info("Tests passed in %i seconds", total_duration) if __name__ == "__main__": From ea775b0662b952849ac7fe2026fc3fd4714c37e3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 29 Jun 2015 21:41:59 -0700 Subject: [PATCH 030/274] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1767 (close requested by 'andrewor14') Closes #6952 (close requested by 'andrewor14') Closes #7051 (close requested by 'andrewor14') Closes #5357 (close requested by 'marmbrus') Closes #5233 (close requested by 'andrewor14') Closes #6930 (close requested by 'JoshRosen') Closes #5502 (close requested by 'andrewor14') Closes #6778 (close requested by 'andrewor14') Closes #7006 (close requested by 'andrewor14') From f79410c49b2225b2acdc58293574860230987775 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Jun 2015 22:32:43 -0700 Subject: [PATCH 031/274] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes. Author: Reynold Xin Closes #7109 from rxin/auto-cast and squashes the following commits: a914cc3 [Reynold Xin] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes. --- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 118 ++++++++---------- .../spark/sql/catalyst/expressions/misc.scala | 6 +- .../sql/catalyst/expressions/predicates.scala | 6 +- .../expressions/stringOperations.scala | 10 +- 6 files changed, 71 insertions(+), 79 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 976fa57cb98d5..c3d68197d64ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -116,7 +116,7 @@ trait HiveTypeCoercion { IfCoercion :: Division :: PropagateTypes :: - ExpectedInputConversion :: + AddCastForAutoCastInputTypes :: Nil /** @@ -709,15 +709,15 @@ trait HiveTypeCoercion { /** * Casts types according to the expected input types for Expressions that have the trait - * `ExpectsInputTypes`. + * [[AutoCastInputTypes]]. */ - object ExpectedInputConversion extends Rule[LogicalPlan] { + object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f59db3d5dfc23..e5dc7b9b5c884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -261,7 +261,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. */ -trait ExpectsInputTypes { +trait AutoCastInputTypes { self: Expression => def expectedChildTypes: Seq[DataType] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 4b57ddd9c5768..a022f3727bd58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -56,7 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) @@ -99,7 +99,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -211,19 +211,11 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with ExpectsInputTypes { - - val name: String = "BIN" - - override def foldable: Boolean = child.foldable - override def nullable: Boolean = true - override def toString: String = s"$name($child)" + extends UnaryExpression with Serializable with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType - def funcName: String = name.toLowerCase - override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { @@ -239,61 +231,13 @@ case class Bin(child: Expression) } } -//////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Binary math functions -//////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -case class Atan2(left: Expression, right: Expression) - extends BinaryMathExpression(math.atan2, "ATAN2") { - - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, - evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} - -case class Pow(left: Expression, right: Expression) - extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. - * Otherwise if the number is a STRING, - * it converts each character into its hexadecimal representation and returns the resulting STRING. - * Negative numbers would be treated as two's complement. + * Otherwise if the number is a STRING, it converts each character into its hex representation + * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) - extends UnaryExpression with Serializable { +case class Hex(child: Expression) extends UnaryExpression with Serializable { override def dataType: DataType = StringType @@ -337,7 +281,7 @@ case class Hex(child: Expression) private def doHex(bytes: Array[Byte], length: Int): UTF8String = { val value = new Array[Byte](length * 2) var i = 0 - while(i < length) { + while (i < length) { value(i * 2) = Character.toUpperCase(Character.forDigit( (bytes(i) & 0xF0) >>> 4, 16)).toByte value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( @@ -362,6 +306,54 @@ case class Hex(child: Expression) } } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Binary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 9a39165a1ff05..27805bff293f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * For input of type [[BinaryType]] */ case class Md5(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = StringType @@ -61,7 +61,7 @@ case class Md5(child: Expression) * the hash length is not one of the permitted values, the return value is NULL. */ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ExpectsInputTypes { + extends BinaryExpression with Serializable with AutoCastInputTypes { override def dataType: DataType = StringType @@ -146,7 +146,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3a12d03ba6bb9..386cf6a8df6df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -70,7 +70,7 @@ trait PredicateHelper { } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { +case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" @@ -123,7 +123,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -172,7 +172,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index a6225fdafedde..ce184e4f32f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ExpectsInputTypes { +trait StringRegexExpression extends AutoCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } -trait CaseConversionExpression extends ExpectsInputTypes { +trait CaseConversionExpression extends AutoCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -158,7 +158,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends AutoCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -221,7 +221,7 @@ case class EndsWith(left: Expression, right: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with AutoCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -295,7 +295,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = IntegerType override def expectedChildTypes: Seq[DataType] = Seq(StringType) From e6c3f7462b3fde220ec0084b52388dd4dabb75b9 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Mon, 29 Jun 2015 22:34:38 -0700 Subject: [PATCH 032/274] [SPARK-8650] [SQL] Use the user-specified app name priority in SparkSQLCLIDriver or HiveThriftServer2 When run `./bin/spark-sql --name query1.sql` [Before] ![before](https://cloud.githubusercontent.com/assets/1400819/8370336/fa20b75a-1bf8-11e5-9171-040049a53240.png) [After] ![after](https://cloud.githubusercontent.com/assets/1400819/8370189/dcc35cb4-1bf6-11e5-8796-a0694140bffb.png) Author: Yadong Qi Closes #7030 from watermen/SPARK-8650 and squashes the following commits: 51b5134 [Yadong Qi] Improve code and add comment. e3d7647 [Yadong Qi] use spark.app.name priority. --- .../apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 79eda1f5123bf..1d41c46131828 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -38,9 +38,14 @@ private[hive] object SparkSQLEnv extends Logging { val sparkConf = new SparkConf(loadDefaults = true) val maybeSerializer = sparkConf.getOption("spark.serializer") val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") + // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of + // the default appName [SparkSQLCLIDriver] in cli or beeline. + val maybeAppName = sparkConf + .getOption("spark.app.name") + .filterNot(_ == classOf[SparkSQLCLIDriver].getName) sparkConf - .setAppName(s"SparkSQL::${Utils.localHostName()}") + .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) .set( "spark.serializer", maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) From 6c5a6db4d53d6db8aa3464ea6713cf0d3a3bdfb5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 23:08:51 -0700 Subject: [PATCH 033/274] [SPARK-5161] [HOTFIX] Fix bug in Python test failure reporting This patch fixes a bug introduced in #7031 which can cause Jenkins to incorrectly report a build with failed Python tests as passing if an error occurred while printing the test failure message. Author: Josh Rosen Closes #7112 from JoshRosen/python-tests-hotfix and squashes the following commits: c3f2961 [Josh Rosen] Hotfix for bug in Python test failure reporting --- python/run-tests.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index aaa35e936a806..b7737650daa54 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -58,22 +58,33 @@ def run_individual_python_test(test_name, pyspark_python): env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() - per_test_output = tempfile.TemporaryFile() - retcode = subprocess.Popen( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], - stderr=per_test_output, stdout=per_test_output, env=env).wait() + try: + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() + except: + LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(1) duration = time.time() - start_time # Exit on the first failure. if retcode != 0: - with FAILURE_REPORTING_LOCK: - with open(LOG_FILE, 'ab') as log_file: + try: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output) per_test_output.seek(0) - log_file.writelines(per_test_output.readlines()) - per_test_output.seek(0) - for line in per_test_output: - if not re.match('[0-9]+', line): - print(line, end='') - per_test_output.close() + for line in per_test_output: + decoded_line = line.decode() + if not re.match('[0-9]+', decoded_line): + print(decoded_line, end='') + per_test_output.close() + except: + LOGGER.exception("Got an exception while trying to print failed test output") + finally: print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if # this code is invoked from a thread other than the main thread. From 12671dd5e468beedc2681ff2bdf95fba81f8f29c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 29 Jun 2015 23:44:11 -0700 Subject: [PATCH 034/274] [SPARK-8434][SQL]Add a "pretty" parameter to the "show" method to display long strings Sometimes the user may want to show the complete content of cells. Now `sql("set -v").show()` displays: ![screen shot 2015-06-18 at 4 34 51 pm](https://cloud.githubusercontent.com/assets/1000778/8227339/14d3c5ea-15d9-11e5-99b9-f00b7e93beef.png) The user needs to use something like `sql("set -v").collect().foreach(r => r.toSeq.mkString("\t"))` to show the complete content. This PR adds a `pretty` parameter to show. If `pretty` is false, `show` won't truncate strings or align cells right. ![screen shot 2015-06-18 at 4 21 44 pm](https://cloud.githubusercontent.com/assets/1000778/8227407/b6f8dcac-15d9-11e5-8219-8079280d76fc.png) Author: zsxwing Closes #6877 from zsxwing/show and squashes the following commits: 22e28e9 [zsxwing] pretty -> truncate e582628 [zsxwing] Add pretty parameter to the show method in R a3cd55b [zsxwing] Fix calling showString in R 923cee4 [zsxwing] Add a "pretty" parameter to show to display long strings --- R/pkg/R/DataFrame.R | 4 +- python/pyspark/sql/dataframe.py | 7 ++- .../org/apache/spark/sql/DataFrame.scala | 55 ++++++++++++++++--- .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++ 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6feabf4189c2d..60702824acb46 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -169,8 +169,8 @@ setMethod("isLocal", #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 152b87351db31..4b9efa0a210fb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -247,9 +247,12 @@ def isLocal(self): return self._jdf.isLocal() @since(1.3) - def show(self, n=20): + def show(self, n=20, truncate=True): """Prints the first ``n`` rows to the console. + :param n: Number of rows to show. + :param truncate: Whether truncate long strings and align cells right. + >>> df DataFrame[age: int, name: string] >>> df.show() @@ -260,7 +263,7 @@ def show(self, n=20): | 5| Bob| +---+-----+ """ - print(self._jdf.showString(n)) + print(self._jdf.showString(n, truncate)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 986e59133919f..8fe1f7e34cb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -169,8 +169,9 @@ class DataFrame private[sql]( /** * Internal API for Python * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int): String = { + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) val sb = new StringBuilder val takeResult = take(numRows + 1) @@ -188,7 +189,7 @@ class DataFrame private[sql]( case seq: Seq[_] => seq.mkString("[", ", ", "]") case _ => cell.toString } - if (str.length > 20) str.substring(0, 17) + "..." else str + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str }: Seq[String] } @@ -207,7 +208,11 @@ class DataFrame private[sql]( // column names rows.head.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") sb.append(sep) @@ -215,7 +220,11 @@ class DataFrame private[sql]( // data rows.tail.map { _.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell.toString, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") } @@ -331,7 +340,8 @@ class DataFrame private[sql]( def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[DataFrame]] in a tabular form. Strings more than 20 characters will be + * truncated, and all cells will be aligned right. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -345,15 +355,46 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def show(numRows: Int): Unit = println(showString(numRows)) + def show(numRows: Int): Unit = show(numRows, true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. * @group action * @since 1.3.0 */ def show(): Unit = show(20) + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + /** * Returns a [[DataFrameNaFunctions]] for working with missing data. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d06b9c5785527..50d324c0686fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -492,6 +492,27 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show(1000) } + test("showString: truncate = [true, false]") { + val longString = Array.fill(21)("1").mkString + val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = """+---------------------+ + ||_1 | + |+---------------------+ + ||1 | + ||111111111111111111111| + |+---------------------+ + |""".stripMargin + assert(df.showString(10, false) === expectedAnswerForFalse) + val expectedAnswerForTrue = """+--------------------+ + || _1| + |+--------------------+ + || 1| + ||11111111111111111...| + |+--------------------+ + |""".stripMargin + assert(df.showString(10, true) === expectedAnswerForTrue) + } + test("showString(negative)") { val expectedAnswer = """+---+-----+ ||key|value| From 5452457410ffe881773f2f2cdcdc752467b19720 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Mon, 29 Jun 2015 23:50:34 -0700 Subject: [PATCH 035/274] [SPARK-8551] [ML] Elastic net python code example Author: Shuo Xiang Closes #6946 from coderxiang/en-java-code-example and squashes the following commits: 7a4bdf8 [Shuo Xiang] address comments cddb02b [Shuo Xiang] add elastic net python example code f4fa534 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 6ad4865 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 180b496 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' aa0717d [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 5f109b4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' c5c5bfe [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 98804c9 [Shuo Xiang] fix bug in topBykey and update test --- .../src/main/python/ml/logistic_regression.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 examples/src/main/python/ml/logistic_regression.py diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 0000000000000..55afe1b207fe0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() From 2ed0c0ac4686ea779f98713978e37b97094edc1c Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Tue, 30 Jun 2015 13:49:52 +0100 Subject: [PATCH 036/274] [SPARK-7756] [CORE] More robust SSL options processing. Subset the enabled algorithms in an SSLOptions to the elements that are supported by the protocol provider. Update the list of ciphers in the sample config to include modern algorithms, and specify both Oracle and IBM names. In practice the user would either specify their own chosen cipher suites, or specify none, and delegate the decision to the provider. Author: Tim Ellison Closes #7043 from tellison/SSLEnhancements and squashes the following commits: 034efa5 [Tim Ellison] Ensure Java imports are grouped and ordered by package. 3797f8b [Tim Ellison] Remove unnecessary use of Option to improve clarity, and fix import style ordering. 4b5c89f [Tim Ellison] More robust SSL options processing. --- .../scala/org/apache/spark/SSLOptions.scala | 43 ++++++++++++++++--- .../org/apache/spark/SSLOptionsSuite.scala | 20 ++++++--- .../org/apache/spark/SSLSampleConfigs.scala | 24 ++++++++--- .../apache/spark/SecurityManagerSuite.scala | 21 ++++++--- 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af0..32df42d57dbd6 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,7 +17,9 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileInputStream} +import java.security.{KeyStore, NoSuchAlgorithmException} +import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +50,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +66,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -94,7 +97,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +105,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. */ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 376481ba541fa..25b79bce6ab98 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import javax.net.ssl.SSLContext import com.google.common.io.Files import org.apache.spark.util.Utils @@ -29,6 +30,15 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + // Pick two cipher suites that the provider knows about + val sslContext = SSLContext.getInstance("TLSv1.2") + sslContext.init(null, null, null) + val algorithms = sslContext + .getServerSocketFactory + .getDefaultCipherSuites + .take(2) + .toSet + val conf = new SparkConf conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -36,9 +46,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ssl.protocol", "SSLv3") + conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) + conf.set("spark.ssl.protocol", "TLSv1.2") val opts = SSLOptions.parse(conf, "spark.ssl") @@ -52,9 +61,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.trustStorePassword === Some("password")) assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) - assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.protocol === Some("TLSv1.2")) + assert(opts.enabledAlgorithms === algorithms) } test("test resolving property with defaults specified ") { diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 1a099da2c6c8e..33270bec6247c 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -25,6 +25,20 @@ object SSLSampleConfigs { this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + val enabledAlgorithms = + // A reasonable set of TLSv1.2 Oracle security provider suites + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "TLS_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + + // and their equivalent names in the IBM Security provider + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "SSL_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") @@ -33,9 +47,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } @@ -47,9 +60,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e9b64aa82a17a..f34aefca4eb18 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -127,6 +127,17 @@ class SecurityManagerSuite extends SparkFunSuite { test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() + val expectedAlgorithms = Set( + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "TLS_RSA_WITH_AES_256_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "SSL_RSA_WITH_AES_256_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") val securityManager = new SecurityManager(conf) @@ -143,9 +154,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -154,9 +164,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) - assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { From 08fab4843845136358f3a7251e8d90135126b419 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jun 2015 07:58:49 -0700 Subject: [PATCH 037/274] [SPARK-8590] [SQL] add code gen for ExtractValue TODO: use array instead of Seq as internal representation for `ArrayType` Author: Wenchen Fan Closes #6982 from cloud-fan/extract-value and squashes the following commits: e203bc1 [Wenchen Fan] address comments 4da0f0b [Wenchen Fan] some clean up f679969 [Wenchen Fan] fix bug e64f942 [Wenchen Fan] remove generic e3f8427 [Wenchen Fan] fix style and address comments fc694e8 [Wenchen Fan] add code gen for extract value --- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 46 ++++-- .../catalyst/expressions/ExtractValue.scala | 76 ++++++++-- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 15 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 13 +- .../sql/catalyst/expressions/predicates.scala | 3 - .../spark/sql/catalyst/expressions/sets.scala | 4 - .../spark/sql/catalyst/util/TypeUtils.scala | 2 +- .../expressions/ComplexTypeSuite.scala | 131 ++++++++++-------- 11 files changed, 199 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5db2fcfcb267b..dc0b4ac5cd9bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -47,7 +47,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) s""" boolean ${ev.isNull} = i.isNullAt($ordinal); ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e5dc7b9b5c884..aed48921bdeb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -179,9 +179,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe + /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * * @param f accepts two variable names and returns Java code to compute the output. @@ -190,15 +191,23 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express ctx: CodeGenContext, ev: GeneratedExpressionCode, f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (left.dataType != right.dataType) { - // log.warn(s"${left.dataType} != ${right.dataType}") - } + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.primitive, eval2.primitive) - + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; @@ -206,7 +215,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express if (!${ev.isNull}) { ${eval2.code} if (!${eval2.isNull}) { - ${ev.primitive} = $resultCode; + $resultCode } else { ${ev.isNull} = true; } @@ -245,13 +254,26 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio ctx: CodeGenContext, ev: GeneratedExpressionCode, f: String => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s"$result = ${f(eval)};" + }) + } + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { val eval = child.gen(ctx) - // reuse the previous isNull - ev.isNull = eval.isNull + val resultCode = f(ev.primitive, eval.primitive) eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = ${f(eval.primitive)}; + $resultCode } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4d7c95ffd1850..3020e7fc967f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -21,6 +21,7 @@ import scala.collection.Map import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ object ExtractValue { @@ -38,7 +39,7 @@ object ExtractValue { def apply( child: Expression, extraction: Expression, - resolver: Resolver): ExtractValue = { + resolver: Resolver): Expression = { (child.dataType, extraction) match { case (StructType(fields), NonNullLiteral(v, StringType)) => @@ -73,7 +74,7 @@ object ExtractValue { def unapply(g: ExtractValue): Option[(Expression, Expression)] = { g match { case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case _ => Some((g.child, null)) + case s: ExtractValueWithStruct => Some((s.child, null)) } } @@ -101,11 +102,11 @@ object ExtractValue { * Note: concrete extract value expressions are created only by `ExtractValue.apply`, * we don't need to do type check for them. */ -trait ExtractValue extends UnaryExpression { - self: Product => +trait ExtractValue { + self: Expression => } -abstract class ExtractValueWithStruct extends ExtractValue { +abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue { self: Product => def field: StructField @@ -125,6 +126,18 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) val baseValue = child.eval(input).asInstanceOf[InternalRow] if (baseValue == null) null else baseValue(ordinal) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + $result = ${ctx.getColumn(eval, dataType, ordinal)}; + } + """ + }) + } } /** @@ -137,6 +150,7 @@ case class GetArrayStructFields( containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable || containsNull || field.nullable override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] @@ -146,18 +160,39 @@ case class GetArrayStructFields( } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = "scala.collection.mutable.ArraySeq" + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + final int n = $eval.size(); + final $arraySeqClass values = new $arraySeqClass(n); + for (int j = 0; j < n; j++) { + InternalRow row = (InternalRow) $eval.apply(j); + if (row != null && !row.isNullAt($ordinal)) { + values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + } + } + $result = (${ctx.javaType(dataType)}) values; + """ + }) + } } -abstract class ExtractValueWithOrdinal extends ExtractValue { +abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { self: Product => def ordinal: Expression + def child: Expression + + override def left: Expression = child + override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. */ override def nullable: Boolean = true - override def foldable: Boolean = child.foldable && ordinal.foldable override def toString: String = s"$child[$ordinal]" - override def children: Seq[Expression] = child :: ordinal :: Nil override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -195,6 +230,19 @@ case class GetArrayItem(child: Expression, ordinal: Expression) baseValue(index) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + final int index = (int)$eval2; + if (index >= $eval1.size() || index < 0) { + ${ev.isNull} = true; + } else { + $result = (${ctx.boxedType(dataType)})$eval1.apply(index); + } + """ + }) + } } /** @@ -209,4 +257,16 @@ case class GetMapValue(child: Expression, ordinal: Expression) val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + if ($eval1.contains($eval2)) { + $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2); + } else { + ${ev.isNull} = true; + } + """ + }) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 3d4d9e2d798f0..ae765c1653203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -82,8 +82,6 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => - /** Name of the function for this expression on a [[Decimal]] type. */ - def decimalMethod: String = "" override def dataType: DataType = left.dataType @@ -113,6 +111,10 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = + sys.error("BinaryArithmetics must override either decimalMethod or genCode") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 57e0bede5db20..bf6a6a124088e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -82,24 +82,24 @@ class CodeGenContext { /** * Returns the code to access a column in Row for a given DataType. */ - def getColumn(dataType: DataType, ordinal: Int): String = { + def getColumn(row: String, dataType: DataType, ordinal: Int): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"i.get${primitiveTypeName(jt)}($ordinal)" + s"$row.get${primitiveTypeName(jt)}($ordinal)" } else { - s"($jt)i.apply($ordinal)" + s"($jt)$row.apply($ordinal)" } } /** * Returns the code to update a column in Row for a given DataType. */ - def setColumn(dataType: DataType, ordinal: Int, value: String): String = { + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"set${primitiveTypeName(jt)}($ordinal, $value)" + s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" } else { - s"update($ordinal, $value)" + s"$row.update($ordinal, $value)" } } @@ -127,6 +127,9 @@ class CodeGenContext { case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType + case _: StructType => "InternalRow" + case _: ArrayType => s"scala.collection.Seq" + case _: MapType => s"scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 64ef357a4f954..addb8023d9c0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a022f3727bd58..da63f2fa970cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -78,17 +78,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + ${ev.primitive} = java.lang.Math.${funcName}($eval); if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; } - } - """ + """ + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 386cf6a8df6df..98cd5aa8148c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,10 +69,7 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index efc6f50b78943..daa9f4403ffab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -135,8 +135,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = left.dataType override def symbol: String = "++=" @@ -185,8 +183,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres */ case class CountSet(child: Expression) extends UnaryExpression { - override def nullable: Boolean = child.nullable - override def dataType: DataType = LongType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 8656cc334d09f..3148309a2166f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types._ /** - * Helper function to check for valid data types + * Helper functions to check for valid data types. */ object TypeUtils { def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b80911e7257fc..3515d044b2f7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -40,51 +40,42 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("GetArrayItem") { + val typeA = ArrayType(StringType) + val array = Literal.create(Seq("a", "b"), typeA) testIntegralDataTypes { convert => - val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") } + val nullArray = Literal.create(null, typeA) + val nullInt = Literal.create(null, IntegerType) + checkEvaluation(GetArrayItem(nullArray, Literal(1)), null) + checkEvaluation(GetArrayItem(array, nullInt), null) + checkEvaluation(GetArrayItem(nullArray, nullInt), null) + + val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } - test("CreateStruct") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), InternalRow(1, 3), row) + test("GetMapValue") { + val typeM = MapType(StringType, StringType) + val map = Literal.create(Map("a" -> "b"), typeM) + val nullMap = Literal.create(null, typeM) + val nullString = Literal.create(null, StringType) + + checkEvaluation(GetMapValue(map, Literal("a")), "b") + checkEvaluation(GetMapValue(map, nullString), null) + checkEvaluation(GetMapValue(nullMap, nullString), null) + checkEvaluation(GetMapValue(map, nullString), null) + + val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM)) + checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) } - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val typeS = StructType( - StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - - def getStructField(expr: Expression, fieldName: String): ExtractValue = { + test("GetStructField") { + val typeS = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), typeS) + val nullStruct = Literal.create(null, typeS) + + def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get @@ -92,28 +83,58 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } - def quickResolve(u: UnresolvedExtractValue): ExtractValue = { - ExtractValue(u.child, u.extraction, _ == _) - } + checkEvaluation(getStructField(struct, "a"), 1) + checkEvaluation(getStructField(nullStruct, "a"), null) + + val nestedStruct = Literal.create(create_row(create_row(1)), + StructType(StructField("a", typeS) :: Nil)) + checkEvaluation(getStructField(nestedStruct, "a"), create_row(1)) + + val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil) + val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable) + val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable) + + assert(getStructField(struct_fieldNotNullable, "a").nullable === false) + assert(getStructField(struct, "a").nullable === true) + assert(getStructField(nullStruct_fieldNotNullable, "a").nullable === true) + assert(getStructField(nullStruct, "a").nullable === true) + } - checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) + test("GetArrayStructFields") { + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) + val nullArrayStruct = Literal.create(null, typeAS) - val typeS_notNullable = StructType( - StructField("a", StringType, nullable = false) - :: StructField("b", StringType, nullable = false) :: Nil - ) + def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { + expr.dataType match { + case ArrayType(StructType(fields), containsNull) => + val field = fields.find(_.name == fieldName).get + GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + } + } + + checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1)) + checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) + } - assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) - assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable - === false) + test("CreateStruct") { + val row = create_row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + } - assert(getStructField(Literal.create(null, typeS), "a").nullable === true) - assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) + test("test dsl for complex type") { + def quickResolve(u: UnresolvedExtractValue): Expression = { + ExtractValue(u.child, u.extraction, _ == _) + } - checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) - checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) - checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) + checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")), + "b", create_row(Map("a" -> "b"))) + checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), + "b", create_row(Seq("a", "b"))) + checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + 1, create_row(create_row(1))) } test("error message of ExtractValue") { From 865a834e51ac3074811a11fd99a36d942f7f7de8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jun 2015 08:08:15 -0700 Subject: [PATCH 038/274] [SPARK-8723] [SQL] improve divide and remainder code gen We can avoid execution of both left and right expression by null and zero check. Author: Wenchen Fan Closes #7111 from cloud-fan/cg and squashes the following commits: d6b12ef [Wenchen Fan] improve divide and remainder code gen --- .../sql/catalyst/expressions/arithmetic.scala | 54 ++++++++++++------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ae765c1653203..5363b3556886a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -216,23 +216,32 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val test = if (left.dataType.isInstanceOf[DecimalType]) { + val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.primitive}.isZero()" } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " - val javaType = ctx.javaType(left.dataType) - eval1.code + eval2.code + - s""" + val javaType = ctx.javaType(dataType) + val divide = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.isNull} || ${eval2.isNull} || $test) { + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { - ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $divide; + } } - """ + """ } } @@ -273,23 +282,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val test = if (left.dataType.isInstanceOf[DecimalType]) { + val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.primitive}.isZero()" } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " - val javaType = ctx.javaType(left.dataType) - eval1.code + eval2.code + - s""" + val javaType = ctx.javaType(dataType) + val remainder = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.isNull} || ${eval2.isNull} || $test) { + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { - ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $remainder; + } } - """ + """ } } From a48e61915354d33fb98944a8eb5a5d48dd102041 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jun 2015 08:17:24 -0700 Subject: [PATCH 039/274] [SPARK-8680] [SQL] Slightly improve PropagateTypes JIRA: https://issues.apache.org/jira/browse/SPARK-8680 This PR slightly improve `PropagateTypes` in `HiveTypeCoercion`. It moves `q.inputSet` outside `q transformExpressions` instead calling `inputSet` multiple times. It also builds a map of attributes for looking attribute easily. Author: Liang-Chi Hsieh Closes #7087 from viirya/improve_propagatetypes and squashes the following commits: 5c314c1 [Liang-Chi Hsieh] For comments. 913f6ad [Liang-Chi Hsieh] Slightly improve PropagateTypes. --- .../catalyst/analysis/HiveTypeCoercion.scala | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index c3d68197d64ac..e525ad623ff12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -131,20 +131,22 @@ trait HiveTypeCoercion { // Don't propagate types from unresolved children. case q: LogicalPlan if !q.childrenResolved => q - case q: LogicalPlan => q transformExpressions { - case a: AttributeReference => - q.inputSet.find(_.exprId == a.exprId) match { - // This can happen when a Attribute reference is born in a non-leaf node, for example - // due to a call to an external script like in the Transform operator. - // TODO: Perhaps those should actually be aliases? - case None => a - // Leave the same if the dataTypes match. - case Some(newType) if a.dataType == newType.dataType => a - case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") - newType - } - } + case q: LogicalPlan => + val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap + q transformExpressions { + case a: AttributeReference => + inputMap.get(a.exprId) match { + // This can happen when a Attribute reference is born in a non-leaf node, for example + // due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave the same if the dataTypes match. + case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + newType + } + } } } From 722aa5f48ec105bf23eee2361adddfe3a0cd6fc4 Mon Sep 17 00:00:00 2001 From: Shilei Date: Tue, 30 Jun 2015 09:49:58 -0700 Subject: [PATCH 040/274] [SPARK-8236] [SQL] misc functions: crc32 https://issues.apache.org/jira/browse/SPARK-8236 Author: Shilei Closes #7108 from qiansl127/Crc32 and squashes the following commits: 5477352 [Shilei] Change to AutoCastInputTypes 5f16e5d [Shilei] Add misc function crc32 --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 40 +++++++++++++++++++ .../expressions/MiscFunctionsSuite.scala | 8 ++++ .../org/apache/spark/sql/functions.scala | 16 ++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 11 +++++ 5 files changed, 76 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b17457d3094c2..d53eaedda56b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -139,6 +139,7 @@ object FunctionRegistry { expression[Sha2]("sha2"), expression[Sha1]("sha1"), expression[Sha1]("sha"), + expression[Crc32]("crc32"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 27805bff293f4..a7bcbe46c339a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.security.MessageDigest import java.security.NoSuchAlgorithmException +import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -168,3 +169,42 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp ) } } + +/** + * A function that computes a cyclic redundancy check value and returns it as a bigint + * For input of type [[BinaryType]] + */ +case class Crc32(child: Expression) + extends UnaryExpression with AutoCastInputTypes { + + override def dataType: DataType = LongType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + val checksum = new CRC32 + checksum.update(value.asInstanceOf[Array[Byte]], 0, value.asInstanceOf[Array[Byte]].length) + checksum.getValue + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val value = child.gen(ctx) + val CRC32 = "java.util.zip.CRC32" + s""" + ${value.code} + boolean ${ev.isNull} = ${value.isNull}; + long ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${CRC32} checksum = new ${CRC32}(); + checksum.update(${value.primitive}, 0, ${value.primitive}.length); + ${ev.primitive} = checksum.getValue(); + } + """ + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 36e636b5da6b8..b524d0af14a67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -49,4 +49,12 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) } + + test("crc32") { + checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L) + checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + 2180413220L) + checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4d9a019058228..6331fe61052ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1466,6 +1466,22 @@ object functions { */ def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) + /** + * Calculates the cyclic redundancy check value and returns the value as a bigint. + * + * @group misc_funcs + * @since 1.5.0 + */ + def crc32(e: Column): Column = Crc32(e.expr) + + /** + * Calculates the cyclic redundancy check value and returns the value as a bigint. + * + * @group misc_funcs + * @since 1.5.0 + */ + def crc32(columnName: String): Column = crc32(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index abfd47c811ed9..11a8767ead96c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -173,6 +173,17 @@ class DataFrameFunctionsSuite extends QueryTest { } } + test("misc crc32 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(crc32($"a"), crc32("b")), + Row(2743272264L, 2180413220L)) + + checkAnswer( + df.selectExpr("crc32(a)", "crc32(b)"), + Row(2743272264L, 2180413220L)) + } + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")), From 689da28a53cf720ae607a1a935093612a7001615 Mon Sep 17 00:00:00 2001 From: xuchenCN Date: Tue, 30 Jun 2015 10:05:51 -0700 Subject: [PATCH 041/274] [SPARK-8592] [CORE] CoarseGrainedExecutorBackend: Cannot register with driver => NPE Look detail of this issue at [SPARK-8592](https://issues.apache.org/jira/browse/SPARK-8592) **CoarseGrainedExecutorBackend** should exit when **RegisterExecutor** failed Author: xuchenCN Closes #7110 from xuchenCN/SPARK-8592 and squashes the following commits: 71e0077 [xuchenCN] [SPARK-8592] [CORE] CoarseGrainedExecutorBackend: Cannot register with driver => NPE --- .../apache/spark/executor/CoarseGrainedExecutorBackend.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index f3a26f54a81fb..34d4cfdca7732 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -66,7 +66,10 @@ private[spark] class CoarseGrainedExecutorBackend( case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } - case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + case Failure(e) => { + logError(s"Cannot register with driver: $driverUrl", e) + System.exit(1) + } }(ThreadUtils.sameThread) } From ada384b785c663392a0b69fad5bfe7a0a0584ee0 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 30 Jun 2015 10:07:26 -0700 Subject: [PATCH 042/274] [SPARK-8437] [DOCS] Corrected: Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' (now fixed scaladoc by using HTML entity for *) Author: Sean Owen Closes #7126 from srowen/SPARK-8437.2 and squashes the following commits: 7bb45da [Sean Owen] Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' (now fixed scaladoc by using HTML entity for *) --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3c3bf3746e18..0e5a86f44e410 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,7 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -878,9 +879,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( From 45281664e0d3b22cd63660ca8ad6dd574f10e21f Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 30 Jun 2015 10:25:59 -0700 Subject: [PATCH 043/274] [SPARK-4127] [MLLIB] [PYSPARK] Python bindings for StreamingLinearRegressionWithSGD Python bindings for StreamingLinearRegressionWithSGD Author: MechCoder Closes #6744 from MechCoder/spark-4127 and squashes the following commits: d8f6457 [MechCoder] Moved StreamingLinearAlgorithm to pyspark.mllib.regression d47cc24 [MechCoder] Inherit from StreamingLinearAlgorithm 1b4ddd6 [MechCoder] minor 4de6c68 [MechCoder] Minor refactor 5e85a3b [MechCoder] Add tests for simultaneous training and prediction fb27889 [MechCoder] Add example and docs 505380b [MechCoder] Add tests d42bdae [MechCoder] [SPARK-4127] Python bindings for StreamingLinearRegressionWithSGD --- docs/mllib-linear-methods.md | 52 +++++++++++ python/pyspark/mllib/classification.py | 50 +--------- python/pyspark/mllib/regression.py | 90 ++++++++++++++++++ python/pyspark/mllib/tests.py | 124 ++++++++++++++++++++++++- 4 files changed, 269 insertions(+), 47 deletions(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3dc8cc902fa72..2a2a7c13186d8 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -768,6 +768,58 @@ will get better! +
+ +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create our model by initializing the weights to 0 + +{% highlight python %} +numFeatures = 3 +model = StreamingLinearRegressionWithSGD() +model.setInitialWeights([0.0, 0.0, 0.0]) +{% endhighlight %} + +Now we register the streams for training and testing and start the job. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! + +
+ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 735d45ba03d27..8f27c446a66e8 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -24,7 +24,9 @@ from pyspark.streaming import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper +from pyspark.mllib.regression import ( + LabeledPoint, LinearModel, _regression_train_wrapper, + StreamingLinearAlgorithm) from pyspark.mllib.util import Saveable, Loader, inherit_doc @@ -585,55 +587,13 @@ def train(cls, data, lambda_=1.0): return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) -class StreamingLinearAlgorithm(object): - """ - Base class that has to be inherited by any StreamingLinearAlgorithm. - - Prevents reimplementation of methods predictOn and predictOnValues. - """ - def __init__(self, model): - self._model = model - - def latestModel(self): - """ - Returns the latest model. - """ - return self._model - - def _validate(self, dstream): - if not isinstance(dstream, DStream): - raise TypeError( - "dstream should be a DStream object, got %s" % type(dstream)) - if not self._model: - raise ValueError( - "Model must be intialized using setInitialWeights") - - def predictOn(self, dstream): - """ - Make predictions on a dstream. - - :return: Transformed dstream object. - """ - self._validate(dstream) - return dstream.map(lambda x: self._model.predict(x)) - - def predictOnValues(self, dstream): - """ - Make predictions on a keyed dstream. - - :return: Transformed dstream object. - """ - self._validate(dstream) - return dstream.mapValues(lambda x: self._model.predict(x)) - - @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a stream of data. + Run LogisticRegression with SGD on a batch of data. The weights obtained at the end of training a stream are used as initial - weights for the next stream. + weights for the next batch. :param stepSize: Step size for each iteration of gradient descent. :param numIterations: Number of iterations run for each batch of data. diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 5ddbbee4babdd..8e90adee5f4c2 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -19,6 +19,7 @@ from numpy import array from pyspark import RDD +from pyspark.streaming.dstream import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector from pyspark.mllib.util import Saveable, Loader @@ -570,6 +571,95 @@ def train(cls, data, isotonic=True): return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic) +class StreamingLinearAlgorithm(object): + """ + Base class that has to be inherited by any StreamingLinearAlgorithm. + + Prevents reimplementation of methods predictOn and predictOnValues. + """ + def __init__(self, model): + self._model = model + + def latestModel(self): + """ + Returns the latest model. + """ + return self._model + + def _validate(self, dstream): + if not isinstance(dstream, DStream): + raise TypeError( + "dstream should be a DStream object, got %s" % type(dstream)) + if not self._model: + raise ValueError( + "Model must be intialized using setInitialWeights") + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +@inherit_doc +class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LinearRegression with SGD on a batch of data. + + The problem minimized is (1 / n_samples) * (y - weights'X)**2. + After training on a batch of data, the weights obtained at the end of + training are used as initial weights for the next batch. + + :param: stepSize Step size for each iteration of gradient descent. + :param: numIterations Total number of iterations run. + :param: miniBatchFraction Fraction of data on which SGD is run for each + iteration. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + self.stepSize = stepSize + self.numIterations = numIterations + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLinearRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn + """ + initialWeights = _convert_to_vector(initialWeights) + self._model = LinearRegressionModel(initialWeights, 0) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LinearRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LinearRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights, + self._model.intercept) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index cd80c3e07a4f7..f0091d6faccce 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -27,8 +27,9 @@ from shutil import rmtree from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean) + array, array_equal, zeros, inf, random, exp, dot, all, mean, abs) from numpy import sum as array_sum + from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -45,8 +46,8 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT -from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec @@ -56,6 +57,7 @@ from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext +from pyspark.streaming import StreamingContext _have_scipy = False try: @@ -1170,6 +1172,124 @@ def collect_errors(rdd): self.assertTrue(errors[1] - errors[-1] > 0.3) +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): + + def assertArrayAlmostEqual(self, array1, array2, dec): + for i, j in array1, array2: + self.assertAlmostEqual(i, j, dec) + + def test_parameter_accuracy(self): + """Test that coefs are predicted accurately by fitting on toy data.""" + + # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients + # (10, 10) + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0, 0.0]) + xMean = [0.0, 0.0] + xVariance = [1.0 / 3.0, 1.0 / 3.0] + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + input_stream = self.ssc.queueStream(batches) + t = time() + slr.trainOn(input_stream) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + + def test_parameter_convergence(self): + """Test that the model parameters improve with streaming data.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + input_stream.foreachRDD( + lambda x: model_weights.append(slr.latestModel().weights[0])) + t = time() + slr.trainOn(input_stream) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + + model_weights = array(model_weights) + diff = model_weights[1:] - model_weights[:-1] + self.assertTrue(all(diff >= -0.1)) + + def test_prediction(self): + """Test prediction on a model with weights already set.""" + # Create a model with initial Weights equal to coefs + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([10.0, 10.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], + 100, 42 + i, 0.1) + batches.append( + sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + + input_stream = self.ssc.queueStream(batches) + t = time() + output_stream = slr.predictOnValues(input_stream) + samples = [] + output_stream.foreachRDD(lambda x: samples.append(x.collect())) + + self.ssc.start() + self._ssc_wait(t, 5, 0.01) + + # Test that mean absolute error on each batch is less than 0.1 + for batch in samples: + true, predicted = zip(*batch) + self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) + + def test_train_prediction(self): + """Test that error on test data improves as model is trained.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in batches] + mean_absolute_errors = [] + + def func(rdd): + true, predicted = zip(*rdd.collect()) + mean_absolute_errors.append(mean(abs(true) - abs(predicted))) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + output_stream = self.ssc.queueStream(predict_batches) + t = time() + slr.trainOn(input_stream) + output_stream = slr.predictOnValues(output_stream) + output_stream.foreachRDD(func) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") From 5fa0863626aaf5a9a41756a0b1ec82bddccbf067 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 30 Jun 2015 10:27:29 -0700 Subject: [PATCH 044/274] [SPARK-8679] [PYSPARK] [MLLIB] Default values in Pipeline API should be immutable It might be dangerous to have a mutable as value for default param. (http://stackoverflow.com/a/11416002/1170730) e.g def func(example, f={}): f[example] = 1 return f func(2) {2: 1} func(3) {2:1, 3:1} mengxr Author: MechCoder Closes #7058 from MechCoder/pipeline_api_playground and squashes the following commits: 40a5eb2 [MechCoder] copy 95f7ff2 [MechCoder] [SPARK-8679] [PySpark] [MLlib] Default values in Pipeline API should be immutable --- python/pyspark/ml/pipeline.py | 24 ++++++++++++++++++------ python/pyspark/ml/wrapper.py | 4 +++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a563024b2cdcb..9889f56cac9e4 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -42,7 +42,7 @@ def _fit(self, dataset): """ raise NotImplementedError() - def fit(self, dataset, params={}): + def fit(self, dataset, params=None): """ Fits a model to the input dataset with optional parameters. @@ -54,6 +54,8 @@ def fit(self, dataset, params={}): list of models. :returns: fitted model(s) """ + if params is None: + params = dict() if isinstance(params, (list, tuple)): return [self.fit(dataset, paramMap) for paramMap in params] elif isinstance(params, dict): @@ -86,7 +88,7 @@ def _transform(self, dataset): """ raise NotImplementedError() - def transform(self, dataset, params={}): + def transform(self, dataset, params=None): """ Transforms the input dataset with optional parameters. @@ -96,6 +98,8 @@ def transform(self, dataset, params={}): params. :returns: transformed dataset """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params,)._transform(dataset) @@ -135,10 +139,12 @@ class Pipeline(Estimator): """ @keyword_only - def __init__(self, stages=[]): + def __init__(self, stages=None): """ __init__(self, stages=[]) """ + if stages is None: + stages = [] super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") @@ -162,11 +168,13 @@ def getStages(self): return self._paramMap[self.stages] @keyword_only - def setParams(self, stages=[]): + def setParams(self, stages=None): """ setParams(self, stages=[]) Sets params for Pipeline. """ + if stages is None: + stages = [] kwargs = self.setParams._input_kwargs return self._set(**kwargs) @@ -195,7 +203,9 @@ def _fit(self, dataset): transformers.append(stage) return PipelineModel(transformers) - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() that = Params.copy(self, extra) stages = [stage.copy(extra) for stage in that.getStages()] return that.setStages(stages) @@ -216,6 +226,8 @@ def _transform(self, dataset): dataset = t.transform(dataset) return dataset - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() stages = [stage.copy(extra) for stage in self.stages] return PipelineModel(stages) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 7b0893e2cdadc..253705bde913e 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -166,7 +166,7 @@ def __init__(self, java_model): self._java_obj = java_model self.uid = java_model.uid() - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and @@ -175,6 +175,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = super(JavaModel, self).copy(extra) that._java_obj = self._java_obj.copy(self._empty_java_param_map()) that._transfer_params_to_java() From fbb267ed6fe799a58f88c2fba2d41e954e5f1547 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Jun 2015 10:48:49 -0700 Subject: [PATCH 045/274] [SPARK-8713] Make codegen thread safe Codegen takes three steps: 1. Take a list of expressions, convert them into Java source code and a list of expressions that don't not support codegen (fallback to interpret mode). 2. Compile the Java source into Java class (bytecode) 3. Using the Java class and the list of expression to build a Projection. Currently, we cache the whole three steps, the key is a list of expression, result is projection. Because some of expressions (which may not thread-safe, for example, Random) will be hold by the Projection, the projection maybe not thread safe. This PR change to only cache the second step, then we can build projection using codegen even some expressions are not thread-safe, because the cache will not hold any expression anymore. cc marmbrus rxin JoshRosen Author: Davies Liu Closes #7101 from davies/codegen_safe and squashes the following commits: 7dd41f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe 847bd08 [Davies Liu] don't use scala.refect 4ddaaed [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe 1793cf1 [Davies Liu] make codegen thread safe --- .../sql/catalyst/expressions/Expression.scala | 14 ----------- .../sql/catalyst/expressions/ScalaUDF.scala | 3 --- .../expressions/codegen/CodeGenerator.scala | 25 ++++++++++--------- .../codegen/GenerateOrdering.scala | 9 +++---- .../codegen/GenerateProjection.scala | 7 +++--- .../expressions/namedExpressions.scala | 2 -- .../catalyst/expressions/nullFunctions.scala | 2 -- .../spark/sql/execution/SparkPlan.scala | 6 ++--- .../MonotonicallyIncreasingID.scala | 2 -- .../apache/spark/sql/sources/commands.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 --- 12 files changed, 24 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index aed48921bdeb5..b5063f32fa529 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -60,14 +60,6 @@ abstract class Expression extends TreeNode[Expression] { /** Returns the result of evaluating this expression on a given input Row */ def eval(input: InternalRow = null): Any - /** - * Return true if this expression is thread-safe, which means it could be used by multiple - * threads in the same time. - * - * An expression that is not thread-safe can not be cached and re-used, especially for codegen. - */ - def isThreadSafe: Boolean = true - /** * Returns an [[GeneratedExpressionCode]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. @@ -76,9 +68,6 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - if (!isThreadSafe) { - throw new Exception(s"$this is not thread-safe, can not be used in codegen") - } val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) @@ -178,8 +167,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" - override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe - /** * Short hand for generating binary evaluation code. * If either of the sub-expressions is null, the result of this computation @@ -237,7 +224,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable - override def isThreadSafe: Boolean = child.isThreadSafe /** * Called by unary expressions to generate a code block that returns null if its parent returns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index dbb4381d54c4f..ebabb6f117851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -956,7 +956,4 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:on private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) - - // TODO(davies): make ScalaUDF work with codegen - override def isThreadSafe: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf6a6a124088e..a64027e48a00b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -235,11 +235,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** * Compile the Java source code into a Java class, using Janino. - * - * It will track the time used to compile */ protected def compile(code: String): GeneratedClass = { - val startTime = System.nanoTime() + cache.get(code) + } + + /** + * Compile the Java source code into a Java class, using Janino. + */ + private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) @@ -251,9 +255,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin logError(s"failed to compile:\n $code", e) throw e } - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } @@ -266,16 +267,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * automatically, in order to constrain its memory footprint. Note that this cache does not use * weak keys/values and thus does not respond to memory pressure. */ - protected val cache = CacheBuilder.newBuilder() + private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = { + new CacheLoader[String, GeneratedClass]() { + override def load(code: String): GeneratedClass = { val startTime = System.nanoTime() - val result = create(in) + val result = doCompile(code) val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logInfo(s"Code generated expression $in in $timeMs ms") + logInfo(s"Code generated in $timeMs ms") result } }) @@ -285,7 +286,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin generate(bind(expressions, inputSchema)) /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) + def generate(expressions: InType): OutType = create(canonicalize(expressions)) /** * Create a new codegen context for expression evaluator, used to store those diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 7ed2c5addec9b..97cb16045ae4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -38,7 +38,6 @@ class BaseOrdering extends Ordering[InternalRow] { */ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging { - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) @@ -47,8 +46,6 @@ object GenerateOrdering in.map(BindReferences.bindReference(_, inputSchema)) protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { - val a = newTermName("a") - val b = newTermName("b") val ctx = newCodeGenContext() val comparisons = ordering.zipWithIndex.map { case (order, i) => @@ -56,9 +53,9 @@ object GenerateOrdering val evalB = order.child.gen(ctx) val asc = order.direction == Ascending s""" - i = $a; + i = a; ${evalA.code} - i = $b; + i = b; ${evalB.code} if (${evalA.isNull} && ${evalB.isNull}) { // Nothing @@ -80,7 +77,7 @@ object GenerateOrdering return new SpecificOrdering(expr); } - class SpecificOrdering extends ${typeOf[BaseOrdering]} { + class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 39d32b78cc14a..5be47175fa7f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -32,7 +32,6 @@ abstract class BaseProject extends Projection {} * primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -157,7 +156,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificProjection(expr); } - class SpecificProjection extends ${typeOf[BaseProject]} { + class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; public SpecificProjection($exprType[] expr) { @@ -170,7 +169,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - final class SpecificRow extends ${typeOf[MutableRow]} { + final class SpecificRow extends ${classOf[MutableRow].getName} { $columns @@ -224,7 +223,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public InternalRow copy() { Object[] arr = new Object[${expressions.length}]; ${copyColumns} - return new ${typeOf[GenericInternalRow]}(arr); + return new ${classOf[GenericInternalRow].getName}(arr); } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 6f56a9ec7beb5..81ebda3060c51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -117,8 +117,6 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) - override def isThreadSafe: Boolean = child.isThreadSafe - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5d5911403ece1..78be2824347d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -51,8 +51,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def isThreadSafe: Boolean = children.forall(_.isThreadSafe) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = true; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 47f56b2b7ebe6..7739a9f949c77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled && expressions.forall(_.isThreadSafe)) { + if (codegenEnabled) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) @@ -168,7 +168,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled && expressions.forall(_.isThreadSafe)) { + if(codegenEnabled) { GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) @@ -178,7 +178,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled && expression.isThreadSafe) { + if (codegenEnabled) { GeneratePredicate.generate(expression, inputSchema) } else { InterpretedPredicate.create(expression, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 3b217348b7b7a..68914cf85cb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -48,6 +48,4 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { count += 1 (TaskContext.get().partitionId().toLong << 33) + currentCount } - - override def isThreadSafe: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 54c8eeb41a8ea..42b51caab5ce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -270,7 +270,7 @@ private[sql] case class InsertIntoHadoopFsRelation( inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled && expressions.forall(_.isThreadSafe)) { + if (codegenEnabled) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7005c7079af91..0b875304f9b0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -591,7 +591,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio rdd.map(_.asInstanceOf[InternalRow]) } converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled && requiredOutput.forall(_.isThreadSafe)) { + val buildProjection = if (codegenEnabled) { GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index d7827d56ca8c5..4dea561ae5f60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -120,8 +120,6 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - override def isThreadSafe: Boolean = false - // TODO: Finish input output types. override def eval(input: InternalRow): Any = { unwrap( @@ -180,8 +178,6 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr lazy val dataType: DataType = inspectorToDataType(returnInspector) - override def isThreadSafe: Boolean = false - override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. From 9213f73a8ea09ae343af825a6b576c212cf4a0c7 Mon Sep 17 00:00:00 2001 From: Tijo Thomas Date: Tue, 30 Jun 2015 10:50:45 -0700 Subject: [PATCH 046/274] [SPARK-8615] [DOCUMENTATION] Fixed Sample deprecated code Modified the deprecated jdbc api in the documentation. Author: Tijo Thomas Closes #7039 from tijoparacka/JIRA_8615 and squashes the following commits: 6e73b8a [Tijo Thomas] Reverted new lines 4042fcf [Tijo Thomas] updated to sql documentation a27949c [Tijo Thomas] Fixed Sample deprecated code --- docs/sql-programming-guide.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2786e3d2cd6bf..88c96a9a095b3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1773,9 +1773,9 @@ the Data Sources API. The following options are supported:
{% highlight scala %} -val jdbcDF = sqlContext.load("jdbc", Map( - "url" -> "jdbc:postgresql:dbserver", - "dbtable" -> "schema.tablename")) +val jdbcDF = sqlContext.read.format("jdbc").options( + Map("url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")).load() {% endhighlight %}
@@ -1788,7 +1788,7 @@ Map options = new HashMap(); options.put("url", "jdbc:postgresql:dbserver"); options.put("dbtable", "schema.tablename"); -DataFrame jdbcDF = sqlContext.load("jdbc", options) +DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% endhighlight %} @@ -1798,7 +1798,7 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.read.format('jdbc').options(url = 'jdbc:postgresql:dbserver', dbtable='schema.tablename').load() {% endhighlight %} From ca7e460f7d6fb898dc29236a85520bbe954c8a13 Mon Sep 17 00:00:00 2001 From: nishkamravi2 Date: Tue, 30 Jun 2015 11:12:15 -0700 Subject: [PATCH 047/274] [SPARK-7988] [STREAMING] Round-robin scheduling of receivers by default Minimal PR for round-robin scheduling of receivers. Dense scheduling can be enabled by setting preferredLocation, so a new config parameter isn't really needed. Tested this on a cluster of 6 nodes and noticed 20-25% gain in throughput compared to random scheduling. tdas pwendell Author: nishkamravi2 Author: Nishkam Ravi Closes #6607 from nishkamravi2/master_nravi and squashes the following commits: 1918819 [Nishkam Ravi] Update ReceiverTrackerSuite.scala f747739 [Nishkam Ravi] Update ReceiverTrackerSuite.scala 6127e58 [Nishkam Ravi] Update ReceiverTracker and ReceiverTrackerSuite 9f1abc2 [nishkamravi2] Update ReceiverTrackerSuite.scala ae29152 [Nishkam Ravi] Update test suite with TD's suggestions 48a4a97 [nishkamravi2] Update ReceiverTracker.scala bc23907 [nishkamravi2] Update ReceiverTracker.scala 68e8540 [nishkamravi2] Update SchedulerSuite.scala 4604f28 [nishkamravi2] Update SchedulerSuite.scala 179b90f [nishkamravi2] Update ReceiverTracker.scala 242e677 [nishkamravi2] Update SchedulerSuite.scala 7f3e028 [Nishkam Ravi] Update ReceiverTracker.scala, add unit test cases in SchedulerSuite f8a3e05 [nishkamravi2] Update ReceiverTracker.scala 4cf97b6 [nishkamravi2] Update ReceiverTracker.scala 16e84ec [Nishkam Ravi] Update ReceiverTracker.scala 45e3a99 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi 02dbdb8 [Nishkam Ravi] Update ReceiverTracker.scala 07b9dfa [nishkamravi2] Update ReceiverTracker.scala 6caeefe [nishkamravi2] Update ReceiverTracker.scala 7888257 [nishkamravi2] Update ReceiverTracker.scala 6e3515c [Nishkam Ravi] Minor changes 975b8d8 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi 3cac21b [Nishkam Ravi] Generalize the scheduling algorithm b05ee2f [nishkamravi2] Update ReceiverTracker.scala bb5e09b [Nishkam Ravi] Add a new var in receiver to store location information for round-robin scheduling 41705de [nishkamravi2] Update ReceiverTracker.scala fff1b2e [Nishkam Ravi] Round-robin scheduling of streaming receivers --- .../streaming/scheduler/ReceiverTracker.scala | 64 ++++++++++--- .../scheduler/ReceiverTrackerSuite.scala | 90 +++++++++++++++++++ 2 files changed, 141 insertions(+), 13 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index e6cdbec11e94c..644e581cd8279 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap} +import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} import scala.language.existentials +import scala.math.max +import org.apache.spark.rdd._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -272,6 +274,41 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } + /** + * Get the list of executors excluding driver + */ + private def getExecutors(ssc: StreamingContext): List[String] = { + val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList + val driver = ssc.sparkContext.getConf.get("spark.driver.host") + executors.diff(List(driver)) + } + + /** Set host location(s) for each receiver so as to distribute them over + * executors in a round-robin fashion taking into account preferredLocation if set + */ + private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]], + executors: List[String]): Array[ArrayBuffer[String]] = { + val locations = new Array[ArrayBuffer[String]](receivers.length) + var i = 0 + for (i <- 0 until receivers.length) { + locations(i) = new ArrayBuffer[String]() + if (receivers(i).preferredLocation.isDefined) { + locations(i) += receivers(i).preferredLocation.get + } + } + var count = 0 + for (i <- 0 until max(receivers.length, executors.length)) { + if (!receivers(i % receivers.length).preferredLocation.isDefined) { + locations(i % receivers.length) += executors(count) + count += 1 + if (count == executors.length) { + count = 0 + } + } + } + locations + } + /** * Get the receivers from the ReceiverInputDStreams, distributes them to the * worker nodes as a parallel collection, and runs them. @@ -283,18 +320,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false rcvr }) - // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.preferredLocation.isDefined).reduce(_ && _) - - // Create the parallel collection of receivers to distributed them on the worker nodes - val tempRDD = - if (hasLocationPreferences) { - val receiversWithPreferences = receivers.map(r => (r, Seq(r.preferredLocation.get))) - ssc.sc.makeRDD[Receiver[_]](receiversWithPreferences) - } else { - ssc.sc.makeRDD(receivers, receivers.size) - } - val checkpointDirOption = Option(ssc.checkpointDir) val serializableHadoopConf = new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) @@ -311,12 +336,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false supervisor.start() supervisor.awaitTermination() } + // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. if (!ssc.sparkContext.isLocal) { ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() } + // Get the list of executors and schedule receivers + val executors = getExecutors(ssc) + val tempRDD = + if (!executors.isEmpty) { + val locations = scheduleReceivers(receivers, executors) + val roundRobinReceivers = (0 until receivers.length).map(i => + (receivers(i), locations(i))) + ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers) + } else { + ssc.sc.makeRDD(receivers, receivers.size) + } + // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") running = true diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala new file mode 100644 index 0000000000000..a6e783861dbe6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.Utils + +/** Testsuite for receiver scheduling */ +class ReceiverTrackerSuite extends TestSuiteBase { + val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") + val ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val tracker = new ReceiverTracker(ssc) + val launcher = new tracker.ReceiverLauncher() + val executors: List[String] = List("0", "1", "2", "3") + + test("receiver scheduling - all or none have preferred location") { + + def parse(s: String): Array[Array[String]] = { + val outerSplit = s.split("\\|") + val loc = new Array[Array[String]](outerSplit.length) + var i = 0 + for (i <- 0 until outerSplit.length) { + loc(i) = outerSplit(i).split("\\,") + } + loc + } + + def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { + val receivers = + if (preferredLocation) { + Array.tabulate(numReceivers)(i => new DummyReceiver(host = + Some(((i + 1) % executors.length).toString))) + } else { + Array.tabulate(numReceivers)(_ => new DummyReceiver) + } + val locations = launcher.scheduleReceivers(receivers, executors) + val expectedLocations = parse(allocation) + assert(locations.deep === expectedLocations.deep) + } + + testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") + testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") + testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") + } + + test("receiver scheduling - some have preferred location") { + val numReceivers = 4; + val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), + new DummyReceiver, new DummyReceiver, new DummyReceiver) + val locations = launcher.scheduleReceivers(receivers, executors) + assert(locations(0)(0) === "1") + assert(locations(1)(0) === "0") + assert(locations(2)(0) === "1") + assert(locations(0).length === 1) + assert(locations(3).length === 1) + } +} + +/** + * Dummy receiver implementation + */ +private class DummyReceiver(host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + def onStart() { + } + + def onStop() { + } + + override def preferredLocation: Option[String] = host +} From 57264400ac7d9f9c59c387c252a9ed8d93fed4fa Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 11:14:38 -0700 Subject: [PATCH 048/274] [SPARK-8630] [STREAMING] Prevent from checkpointing QueueInputDStream This PR throws an exception in `QueueInputDStream.writeObject` so that it can fail the application when calling `StreamingContext.start` rather than failing it during recovering QueueInputDStream. Author: zsxwing Closes #7016 from zsxwing/queueStream-checkpoint and squashes the following commits: 89a3d73 [zsxwing] Fix JavaAPISuite.testQueueStream cc40fd7 [zsxwing] Prevent from checkpointing QueueInputDStream --- .../spark/streaming/StreamingContext.scala | 8 ++++++++ .../api/java/JavaStreamingContext.scala | 18 +++++++++++++++--- .../streaming/dstream/QueueInputDStream.scala | 15 ++++++++++----- .../apache/spark/streaming/JavaAPISuite.java | 8 ++++++++ .../streaming/StreamingContextSuite.scala | 15 +++++++++++++++ 5 files changed, 56 insertions(+), 8 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 1708f309fc002..ec49d0f42d122 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -477,6 +477,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -491,6 +495,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 989e3a729ebc2..40deb6d7ea79a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -419,7 +419,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @tparam T Type of objects in the RDD */ @@ -435,7 +439,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -455,7 +463,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index ed7da6dc1315e..a2f5d82a79bd3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.streaming.{Time, StreamingContext} +import java.io.{NotSerializableException, ObjectOutputStream} + +import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.streaming.{Time, StreamingContext} + private[streaming] class QueueInputDStream[T: ClassTag]( @transient ssc: StreamingContext, @@ -36,6 +37,10 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def writeObject(oos: ObjectOutputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing") + } + override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() if (oneAtATime && queue.size > 0) { diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 1077b1b2cb7e3..a34f23475804a 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -364,6 +364,14 @@ private void testReduceByWindow(boolean withInverse) { @SuppressWarnings("unchecked") @Test public void testQueueStream() { + ssc.stop(); + // Create a new JavaStreamingContext without checkpointing + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + List> expected = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 819dd2ccfe915..56b4ce5638a51 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.Queue + import org.apache.commons.io.FileUtils import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts @@ -665,6 +667,19 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo transformed.foreachRDD { rdd => rdd.collect() } } } + test("queueStream doesn't support checkpointing") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(master, appName, batchDuration) + val rdd = ssc.sparkContext.parallelize(1 to 10) + ssc.queueStream[Int](Queue(rdd)).print() + ssc.checkpoint(checkpointDir.getAbsolutePath) + val e = intercept[NotSerializableException] { + ssc.start() + } + // StreamingContext.validate changes the message, so use "contains" here + assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) From d16a9443750eebb7a3d7688d4b98a2ac39cc0da7 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 30 Jun 2015 11:46:22 -0700 Subject: [PATCH 049/274] [SPARK-8619] [STREAMING] Don't recover keytab and principal configuration within Streaming checkpoint [Client.scala](https://github.com/apache/spark/blob/master/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala#L786) will change these configurations, so this would cause the problem that the Streaming recover logic can't find the local keytab file(since configuration was changed) ```scala sparkConf.set("spark.yarn.keytab", keytabFileName) sparkConf.set("spark.yarn.principal", args.principal) ``` Problem described at [Jira](https://issues.apache.org/jira/browse/SPARK-8619) Author: huangzhaowei Closes #7008 from SaintBacchus/SPARK-8619 and squashes the following commits: d50dbdf [huangzhaowei] Delect one blank space 9b8e92c [huangzhaowei] Fix code style and add a short comment. 0d8f800 [huangzhaowei] Don't recover keytab and principal configuration within Streaming checkpoint. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index d8dc4e4101664..5279331c9e122 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -44,11 +44,23 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val sparkConfPairs = ssc.conf.getAll def createSparkConf(): SparkConf = { + + // Reload properties for the checkpoint application since user wants to set a reload property + // or spark had changed its value and user wants to set it back. + val propertiesToReload = List( + "spark.master", + "spark.yarn.keytab", + "spark.yarn.principal") + val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.port") - val newMasterOption = new SparkConf(loadDefaults = true).getOption("spark.master") - newMasterOption.foreach { newMaster => newSparkConf.setMaster(newMaster) } + val newReloadConf = new SparkConf(loadDefaults = true) + propertiesToReload.foreach { prop => + newReloadConf.getOption(prop).foreach { value => + newSparkConf.set(prop, value) + } + } newSparkConf } From 1e1f339976641af4cc87d4010db57c3b600f91af Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Tue, 30 Jun 2015 12:22:34 -0700 Subject: [PATCH 050/274] [SPARK-6785] [SQL] fix DateTimeUtils for dates before 1970 Hi Michael, this Pull-Request is a follow-up to [PR-6242](https://github.com/apache/spark/pull/6242). I removed the two obsolete test cases from the HiveQuerySuite and deleted the corresponding golden answer files. Thanks for your review! Author: Christian Kadner Closes #6983 from ckadner/SPARK-6785 and squashes the following commits: ab1e79b [Christian Kadner] Merge remote-tracking branch 'origin/SPARK-6785' into SPARK-6785 1fed877 [Christian Kadner] [SPARK-6785][SQL] failed Scala style test, remove spaces on empty line DateTimeUtils.scala:61 9d8021d [Christian Kadner] [SPARK-6785][SQL] merge recent changes in DateTimeUtils & MiscFunctionsSuite b97c3fb [Christian Kadner] [SPARK-6785][SQL] move test case for DateTimeUtils to DateTimeUtilsSuite a451184 [Christian Kadner] [SPARK-6785][SQL] fix DateTimeUtils.fromJavaDate(java.util.Date) for Dates before 1970 --- .../sql/catalyst/util/DateTimeUtils.scala | 8 ++-- .../catalyst/util/DateTimeUtilsSuite.scala | 40 ++++++++++++++++++- .../sql/ScalaReflectionRelationSuite.scala | 2 +- ...te cast-0-a7cd69b80c77a771a2c955db666be53d | 1 - ... test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 | 1 - .../sql/hive/execution/HiveQuerySuite.scala | 14 ------- .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++- 7 files changed, 75 insertions(+), 22 deletions(-) delete mode 100644 sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d delete mode 100644 sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 640e67e2ecd76..4269ad5d56737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -59,10 +59,12 @@ object DateTimeUtils { } } - // we should use the exact day as Int, for example, (year, month, day) -> day - def millisToDays(millisLocal: Long): Int = { - ((millisLocal + threadLocalLocalTimeZone.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + def millisToDays(millisUtc: Long): Int = { + // SPARK-6785: use Math.floor so negative number of days (dates before 1970) + // will correctly work as input for function toJavaDate(Int) + val millisLocal = millisUtc.toDouble + threadLocalLocalTimeZone.get().getOffset(millisUtc) + Math.floor(millisLocal / MILLIS_PER_DAY).toInt } // reverse of millisToDays diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 03eb64f097a37..1d4a60c81efc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.util -import java.sql.Timestamp +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat import org.apache.spark.SparkFunSuite @@ -48,4 +49,41 @@ class DateTimeUtilsSuite extends SparkFunSuite { val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) assert(t.equals(t2)) } + + test("SPARK-6785: java date conversion before and after epoch") { + def checkFromToJavaDate(d1: Date): Unit = { + val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1)) + assert(d2.toString === d1.toString) + } + + val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z") + + checkFromToJavaDate(new Date(100)) + + checkFromToJavaDate(Date.valueOf("1970-01-01")) + + checkFromToJavaDate(new Date(df1.parse("1970-01-01 00:00:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1970-01-01 00:00:00 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1970-01-01 00:00:01").getTime)) + checkFromToJavaDate(new Date(df2.parse("1970-01-01 00:00:01 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1969-12-31 23:59:59").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-12-31 23:59:59 UTC").getTime)) + + checkFromToJavaDate(Date.valueOf("1969-01-01")) + + checkFromToJavaDate(new Date(df1.parse("1969-01-01 00:00:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-01-01 00:00:00 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1969-01-01 00:00:01").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-01-01 00:00:01 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1989-11-09 11:59:59").getTime)) + checkFromToJavaDate(new Date(df2.parse("1989-11-09 19:59:59 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1776-07-04 10:30:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1776-07-04 18:30:00 UTC").getTime)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 4cb5ba2f0d5eb..ab6d3dd96d271 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -78,7 +78,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") assert(ctx.sql("SELECT * FROM reflectData").collect().head === diff --git a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d deleted file mode 100644 index 98da82fa89386..0000000000000 --- a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d +++ /dev/null @@ -1 +0,0 @@ -1970-01-01 1970-01-01 1969-12-31 16:00:00 1969-12-31 16:00:00 1970-01-01 00:00:00 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 deleted file mode 100644 index 27ba77ddaf615..0000000000000 --- a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 +++ /dev/null @@ -1 +0,0 @@ -true diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 51dabc67fa7c1..4cdba03b27022 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -324,20 +324,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { | FROM src LIMIT 1 """.stripMargin) - createQueryTest("Date comparison test 2", - "SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1") - - createQueryTest("Date cast", - """ - | SELECT - | CAST(CAST(0 AS timestamp) AS date), - | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), - | CAST(0 AS timestamp), - | CAST(CAST(0 AS timestamp) AS string), - | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) - | FROM src LIMIT 1 - """.stripMargin) - createQueryTest("Simple Average", "SELECT AVG(key) FROM src") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9f7e58f890241..6d645393a6da1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.hive.execution +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -962,4 +964,31 @@ class SQLQuerySuite extends QueryTest { case None => // OK } } + + test("SPARK-6785: HiveQuerySuite - Date comparison test 2") { + checkAnswer( + sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), + Row(false)) + } + + test("SPARK-6785: HiveQuerySuite - Date cast") { + // new Date(0) == 1970-01-01 00:00:00.0 GMT == 1969-12-31 16:00:00.0 PST + checkAnswer( + sql( + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 + """.stripMargin), + Row( + Date.valueOf("1969-12-31"), + String.valueOf("1969-12-31"), + Timestamp.valueOf("1969-12-31 16:00:00"), + String.valueOf("1969-12-31 16:00:00"), + Timestamp.valueOf("1970-01-01 00:00:00"))) + } } From c1befd780c3defc843baa75097de7ec427d3f8ca Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 30 Jun 2015 12:23:48 -0700 Subject: [PATCH 051/274] [SPARK-8664] [ML] Add PCA transformer Add PCA transformer for ML pipeline Author: Yanbo Liang Closes #7065 from yanboliang/spark-8664 and squashes the following commits: 4afae45 [Yanbo Liang] address comments e9effd7 [Yanbo Liang] Add PCA transformer --- .../org/apache/spark/ml/feature/PCA.scala | 130 ++++++++++++++++++ .../org/apache/spark/mllib/feature/PCA.scala | 2 +- .../apache/spark/ml/feature/PCASuite.scala | 64 +++++++++ 3 files changed, 195 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala new file mode 100644 index 0000000000000..2d3bb680cf309 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[PCA]] and [[PCAModel]]. + */ +private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol { + + /** + * The number of principal components. + * @group param + */ + final val k: IntParam = new IntParam(this, "k", "the number of principal components") + + /** @group getParam */ + def getK: Int = $(k) + +} + +/** + * :: Experimental :: + * PCA trains a model to project vectors to a low-dimensional space using PCA. + */ +@Experimental +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { + + def this() = this(Identifiable.randomUID("pca")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setK(value: Int): this.type = set(k, value) + + /** + * Computes a [[PCAModel]] that contains the principal components of the input vectors. + */ + override def fit(dataset: DataFrame): PCAModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} + val pca = new feature.PCA(k = $(k)) + val pcaModel = pca.fit(input) + copyValues(new PCAModel(uid, pcaModel).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCA = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[PCA]]. + */ +@Experimental +class PCAModel private[ml] ( + override val uid: String, + pcaModel: feature.PCAModel) + extends Model[PCAModel] with PCAParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a vector by computed Principal Components. + * NOTE: Vectors to be transformed must be the same length + * as the source vectors given to [[PCA.fit()]]. + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val pcaOp = udf { pcaModel.transform _ } + dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCAModel = { + val copied = new PCAModel(uid, pcaModel) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 4e01e402b4283..2a66263d8b7d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -68,7 +68,7 @@ class PCA(val k: Int) { * @param k number of principal components. * @param pc a principal components Matrix. Each column is one principal component. */ -class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { +class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { /** * Transform a vector by computed Principal Components. * diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala new file mode 100644 index 0000000000000..d0ae36b28c7a9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} +import org.apache.spark.sql.Row + +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new PCA) + val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] + val model = new PCAModel("pca", new OldPCAModel(2, mat)) + ParamsSuite.checkParams(model) + } + + test("pca") { + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val mat = new RowMatrix(dataRDD) + val pc = mat.computePrincipalComponents(3) + val expected = mat.multiply(pc).rows + + val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pca_features") + .setK(3) + .fit(df) + + pca.transform(df).select("pca_features", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} From b8e5bb6fc1553256e950fdad9cb5acc6b296816e Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 30 Jun 2015 12:24:47 -0700 Subject: [PATCH 052/274] [SPARK-8628] [SQL] Race condition in AbstractSparkSQLParser.parse Made lexical iniatialization as lazy val Author: Vinod K C Closes #7015 from vinodkc/handle_lexical_initialize_schronization and squashes the following commits: b6d1c74 [Vinod K C] Avoided repeated lexical initialization 5863cf7 [Vinod K C] Removed space e27c66c [Vinod K C] Avoid reinitialization of lexical in parse method ef4f60f [Vinod K C] Reverted import order e9fc49a [Vinod K C] handle synchronization in SqlLexical.initialize --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 6 ++++-- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index ef7b3ad9432cf..d494ae7b71d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions import scala.util.parsing.combinator.lexical.StdLexical import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.{PackratParsers, RegexParsers} +import scala.util.parsing.combinator.PackratParsers import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ @@ -30,12 +30,14 @@ private[sql] abstract class AbstractSparkSQLParser def parse(input: String): LogicalPlan = { // Initialize the Keywords. - lexical.initialize(reservedWords) + initLexical phrase(start)(new lexical.Scanner(input)) match { case Success(plan, _) => plan case failureOrError => sys.error(failureOrError.toString) } } + /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */ + protected lazy val initLexical: Unit = lexical.initialize(reservedWords) protected case class Keyword(str: String) { def normalize: String = lexical.normalizeKeyword(str) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 79f526e823cd4..8d02fbf4f92c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -40,7 +40,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { def parseExpression(input: String): Expression = { // Initialize the Keywords. - lexical.initialize(reservedWords) + initLexical phrase(projection)(new lexical.Scanner(input)) match { case Success(plan, _) => plan case failureOrError => sys.error(failureOrError.toString) From 74cc16dbc35e35fd5cd5542239dcb6e5e7f92d18 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 30 Jun 2015 12:31:33 -0700 Subject: [PATCH 053/274] [SPARK-8471] [ML] Discrete Cosine Transform Feature Transformer Implementation and tests for Discrete Cosine Transformer. Author: Feynman Liang Closes #6894 from feynmanliang/dct-features and squashes the following commits: 433dbc7 [Feynman Liang] Test refactoring 91e9636 [Feynman Liang] Style guide and test helper refactor b5ac19c [Feynman Liang] Use Vector types, add Java test 530983a [Feynman Liang] Tests for other numeric datatypes 195d7aa [Feynman Liang] Implement support for arbitrary numeric types 95d4939 [Feynman Liang] Working DCT for 1D Doubles --- .../feature/DiscreteCosineTransformer.scala | 72 +++++++++++++++++ .../JavaDiscreteCosineTransformerSuite.java | 78 +++++++++++++++++++ .../DiscreteCosineTransformerSuite.scala | 73 +++++++++++++++++ 3 files changed, 223 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala new file mode 100644 index 0000000000000..a2f4d59f81c44 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import edu.emory.mathcs.jtransforms.dct._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.BooleanParam +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.types.DataType + +/** + * :: Experimental :: + * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero + * padding is performed on the input vector. + * It returns a real vector of the same length representing the DCT. The return vector is scaled + * such that the transform matrix is unitary (aka scaled DCT-II). + * + * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]]. + */ +@Experimental +class DiscreteCosineTransformer(override val uid: String) + extends UnaryTransformer[Vector, Vector, DiscreteCosineTransformer] { + + def this() = this(Identifiable.randomUID("dct")) + + /** + * Indicates whether to perform the inverse DCT (true) or forward DCT (false). + * Default: false + * @group param + */ + def inverse: BooleanParam = new BooleanParam( + this, "inverse", "Set transformer to perform inverse DCT") + + /** @group setParam */ + def setInverse(value: Boolean): this.type = set(inverse, value) + + /** @group getParam */ + def getInverse: Boolean = $(inverse) + + setDefault(inverse -> false) + + override protected def createTransformFunc: Vector => Vector = { vec => + val result = vec.toArray + val jTransformer = new DoubleDCT_1D(result.length) + if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true) + Vectors.dense(result) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java new file mode 100644 index 0000000000000..28bc5f65e0532 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaDiscreteCosineTransformerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaDiscreteCosineTransformerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void javaCompatibilityTest() { + double[] input = new double[] {1D, 2D, 3D, 4D}; + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.dense(input)) + )); + DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ + new StructField("vec", (new VectorUDT()), false, Metadata.empty()) + })); + + double[] expectedResult = input.clone(); + (new DoubleDCT_1D(input.length)).forward(expectedResult, true); + + DiscreteCosineTransformer DCT = new DiscreteCosineTransformer() + .setInputCol("vec") + .setOutputCol("resultVec"); + + Row[] result = DCT.transform(dataset).select("resultVec").collect(); + Vector resultVec = result[0].getAs("resultVec"); + + Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala new file mode 100644 index 0000000000000..ed0fc11f78f69 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class DCTTestData(vec: Vector, wantedVec: Vector) + +class DiscreteCosineTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("forward transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = false + + testDCT(data, inverse) + } + + test("inverse transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = true + + testDCT(data, inverse) + } + + private def testDCT(data: Vector, inverse: Boolean): Unit = { + val expectedResultBuffer = data.toArray.clone() + if (inverse) { + (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true) + } else { + (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true) + } + val expectedResult = Vectors.dense(expectedResultBuffer) + + val dataset = sqlContext.createDataFrame(Seq( + DCTTestData(data, expectedResult) + )) + + val transformer = new DiscreteCosineTransformer() + .setInputCol("vec") + .setOutputCol("resultVec") + .setInverse(inverse) + + transformer.transform(dataset) + .select("resultVec", "wantedVec") + .collect() + .foreach { case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + } + } +} From 61d7b533dd50bfac2162b4edcea94724bbd8fcb1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 30 Jun 2015 12:44:43 -0700 Subject: [PATCH 054/274] [SPARK-7514] [MLLIB] Add MinMaxScaler to feature transformation jira: https://issues.apache.org/jira/browse/SPARK-7514 Add a popular scaling method to feature component, which is commonly known as min-max normalization or Rescaling. Core function is, Normalized(x) = (x - min) / (max - min) * scale + newBase where `newBase` and `scale` are parameters (type Double) of the `VectorTransformer`. `newBase` is the new minimum number for the features, and `scale` controls the ranges after transformation. This is a little complicated than the basic MinMax normalization, yet it provides flexibility so that users can control the range more specifically. like [0.1, 0.9] in some NN application. For case that `max == min`, 0.5 is used as the raw value. (0.5 * scale + newBase) I'll add UT once the design got settled ( and this is not considered as too naive) reference: http://en.wikipedia.org/wiki/Feature_scaling http://stn.spotfire.com/spotfire_client_help/index.htm#norm/norm_scale_between_0_and_1.htm Author: Yuhao Yang Closes #6039 from hhbyyh/minMaxNorm and squashes the following commits: f942e9f [Yuhao Yang] add todo for metadata 8b37bbc [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 4894dbc [Yuhao Yang] add copy fa2989f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 29db415 [Yuhao Yang] add clue and minor adjustment 5b8f7cc [Yuhao Yang] style fix 9b133d0 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 22f20f2 [Yuhao Yang] style change and bug fix 747c9bb [Yuhao Yang] add ut and remove mllib version a5ba0aa [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 585cc07 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 1c6dcb1 [Yuhao Yang] minor change 0f1bc80 [Yuhao Yang] add MinMaxScaler to ml 8e7436e [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 3663165 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 1247c27 [Yuhao Yang] some comments improvement d285a19 [Yuhao Yang] initial checkin for minMaxNorm --- .../spark/ml/feature/MinMaxScaler.scala | 170 ++++++++++++++++++ .../spark/ml/feature/MinMaxScalerSuite.scala | 68 +++++++ 2 files changed, 238 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala new file mode 100644 index 0000000000000..b30adf3df48d2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]]. + */ +private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * lower bound after transformation, shared by all features + * Default: 0.0 + * @group param + */ + val min: DoubleParam = new DoubleParam(this, "min", + "lower bound of the output feature range") + + /** + * upper bound after transformation, shared by all features + * Default: 1.0 + * @group param + */ + val max: DoubleParam = new DoubleParam(this, "max", + "upper bound of the output feature range") + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def validateParams(): Unit = { + require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to a common range [min, max] linearly using column summary + * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + * feature E is calculated as, + * + * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min + * + * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min) + * Note that since zero values will probably be transformed to non-zero values, output of the + * transformer will be DenseVector even for sparse input. + */ +@Experimental +class MinMaxScaler(override val uid: String) + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + + def this() = this(Identifiable.randomUID("minMaxScal")) + + setDefault(min -> 0.0, max -> 1.0) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + override def fit(dataset: DataFrame): MinMaxScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[MinMaxScaler]]. + * + * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). + */ +@Experimental +class MinMaxScalerModel private[ml] ( + override val uid: String, + val originalMin: Vector, + val originalMax: Vector) + extends Model[MinMaxScalerModel] with MinMaxScalerParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + + override def transform(dataset: DataFrame): DataFrame = { + val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray + val minArray = originalMin.toArray + + val reScale = udf { (vector: Vector) => + val scale = $(max) - $(min) + + // 0 in sparse vector will probably be rescaled to non-zero + val values = vector.toArray + val size = values.size + var i = 0 + while (i < size) { + val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5 + values(i) = raw * scale + $(min) + i += 1 + } + Vectors.dense(values) + } + + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScalerModel = { + val copied = new MinMaxScalerModel(uid, originalMin, originalMax) + copyValues(copied, extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala new file mode 100644 index 0000000000000..c452054bec92f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("MinMaxScaler fit basic case") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.dense(1, 0, Long.MinValue), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)), + Vectors.sparse(3, Array(0), Array(1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(-5, 0, -5), + Vectors.dense(0, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(5, 5)), + Vectors.sparse(3, Array(0), Array(-2.5))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaled") + .setMin(-5) + .setMax(5) + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), "Transformed vector is different with expected.") + } + } + + test("MinMaxScaler arguments max must be larger than min") { + withClue("arguments max must be larger than min") { + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(10).setMax(0) + scaler.validateParams() + } + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(0).setMax(0) + scaler.validateParams() + } + } + } +} From 79f0b371a36560a009c1b0943c928adc5a1bdd8f Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 30 Jun 2015 13:56:59 -0700 Subject: [PATCH 055/274] [SPARK-8560] [UI] The Executors page will have negative if having resubmitted tasks when the ```taskEnd.reason``` is ```Resubmitted```, it shouldn't do statistics. Because this tasks has a ```SUCCESS``` taskEnd before. Author: xutingjun Closes #6950 from XuTingjun/pageError and squashes the following commits: af35dc3 [xutingjun] When taskEnd is Resubmitted, don't do statistics --- .../org/apache/spark/ui/exec/ExecutorsTab.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 39583af14390d..a88fc4c37d3c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.HashMap -import org.apache.spark.{ExceptionFailure, SparkContext} +import org.apache.spark.{Resubmitted, ExceptionFailure, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} @@ -92,15 +92,22 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId - executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 - executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration taskEnd.reason match { + case Resubmitted => + // Note: For resubmitted tasks, we continue to use the metrics that belong to the + // first attempt of this task. This may not be 100% accurate because the first attempt + // could have failed half-way through. The correct fix would be to keep track of the + // metrics added by each attempt, but this is much more complicated. + return case e: ExceptionFailure => executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 case _ => executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 } + executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 + executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration + // Update shuffle read/write val metrics = taskEnd.taskMetrics if (metrics != null) { From 7dda0844e1eb6df7455af68592751806b3b92251 Mon Sep 17 00:00:00 2001 From: Joshi Date: Tue, 30 Jun 2015 14:00:35 -0700 Subject: [PATCH 056/274] [SPARK-2645] [CORE] Allow SparkEnv.stop() to be called multiple times without side effects. Fix for SparkContext stop behavior - Allow sc.stop() to be called multiple times without side effects. Author: Joshi Author: Rekha Joshi Closes #6973 from rekhajoshm/SPARK-2645 and squashes the following commits: 277043e [Joshi] Fix for SparkContext stop behavior 446b0a4 [Joshi] Fix for SparkContext stop behavior 2ce5760 [Joshi] Fix for SparkContext stop behavior c97839a [Joshi] Fix for SparkContext stop behavior 1aff39c [Joshi] Fix for SparkContext stop behavior 12f66b5 [Joshi] Fix for SparkContext stop behavior 72bb484 [Joshi] Fix for SparkContext stop behavior a5a7d7f [Joshi] Fix for SparkContext stop behavior 9193a0c [Joshi] Fix for SparkContext stop behavior 58dba70 [Joshi] SPARK-2645: Fix for SparkContext stop behavior 380c5b0 [Joshi] SPARK-2645: Fix for SparkContext stop behavior b566b66 [Joshi] SPARK-2645: Fix for SparkContext stop behavior 0be142d [Rekha Joshi] Merge pull request #3 from apache/master 106fd8e [Rekha Joshi] Merge pull request #2 from apache/master e3677c9 [Rekha Joshi] Merge pull request #1 from apache/master --- .../scala/org/apache/spark/SparkEnv.scala | 66 ++++++++++--------- .../org/apache/spark/SparkContextSuite.scala | 13 ++++ 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index b0665570e2681..1b133fbdfaf59 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,7 +22,6 @@ import java.net.Socket import akka.actor.ActorSystem -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties @@ -90,39 +89,42 @@ class SparkEnv ( private var driverTmpDirToDelete: Option[String] = None private[spark] def stop() { - isStopped = true - pythonWorkers.foreach { case(key, worker) => worker.stop() } - Option(httpFileServer).foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() - - // Note that blockTransferService is stopped by BlockManager since it is started by it. - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. - // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the - // current working dir in executor which we do not need to delete. - driverTmpDirToDelete match { - case Some(path) => { - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + Option(httpFileServer).foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. + driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 6838b35ab4cc8..5c57940fa5f77 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.Utils import scala.concurrent.Await import scala.concurrent.duration.Duration +import org.scalatest.Matchers._ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { @@ -272,4 +273,16 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } } + + test("calling multiple sc.stop() must not throw any exception") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val cnt = sc.parallelize(1 to 4).count() + sc.cancelAllJobs() + sc.stop() + // call stop second time + sc.stop() + } + } + } From 4bb8375fc2c6aa8342df03c3617aa97e7d01de3f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 30 Jun 2015 14:01:52 -0700 Subject: [PATCH 057/274] [SPARK-8372] Do not show applications that haven't recorded their app ID yet. Showing these applications may lead to weird behavior in the History Server. For old logs, if the app ID is recorded later, you may end up with a duplicate entry. For new logs, the app might be listed with a ".inprogress" suffix. So ignore those, but still allow old applications that don't record app IDs at all (1.0 and 1.1) to be shown. Author: Marcelo Vanzin Author: Carson Wang Closes #7097 from vanzin/SPARK-8372 and squashes the following commits: a24eab2 [Marcelo Vanzin] Feedback. 112ae8f [Marcelo Vanzin] Merge branch 'master' into SPARK-8372 7b91b74 [Marcelo Vanzin] Handle logs generated by 1.0 and 1.1. 1eca3fe [Carson Wang] [SPARK-8372] History server shows incorrect information for application not started --- .../deploy/history/FsHistoryProvider.scala | 98 ++++++++++------ .../history/FsHistoryProviderSuite.scala | 109 +++++++++++++----- 2 files changed, 147 insertions(+), 60 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 5427a88f32ffd..2cc465e55fceb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -83,12 +83,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - // Constants used to parse Spark 1.0.0 log directories. - private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -146,7 +140,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { try { applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).map { attempt => + appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => val replayBus = new ReplayListenerBus() val ui = { val conf = this.conf.clone() @@ -155,20 +149,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() replayBus.addListener(appListener) val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - - ui.setAppName(s"${appInfo.name} ($appId)") - - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui + appInfo.map { info => + ui.setAppName(s"${info.name} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, + appListener.viewAcls.getOrElse("")) + ui + } } } } catch { @@ -282,8 +276,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = logs.flatMap { fileStatus => try { val res = replay(fileStatus, bus) - logInfo(s"Application log ${res.logPath} loaded successfully.") - Some(res) + res match { + case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") + case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + } + res } catch { case e: Exception => logError( @@ -429,9 +427,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replays the events in the specified log file and returns information about the associated - * application. + * application. Return `None` if the application ID cannot be located. */ - private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = { + private def replay( + eventLog: FileStatus, + bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") val logInput = @@ -445,16 +445,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted) - new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. Some old versions of Spark generate logs without an app ID, so let + // logs generated by those versions go through. + if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + Some(new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(eventLog).get, + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted)) + } else { + None + } } finally { logInput.close() } @@ -529,10 +537,34 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Returns whether the version of Spark that generated logs records app IDs. App IDs were added + * in Spark 1.1. + */ + private def sparkVersionHasAppId(entry: FileStatus): Boolean = { + if (isLegacyLogDirectory(entry)) { + fs.listStatus(entry.getPath()) + .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } + .map { status => + val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) + version != "1.0" && version != "1.1" + } + .getOrElse(true) + } else { + true + } + } + } -private object FsHistoryProvider { +private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + // Constants used to parse Spark 1.0.0 log directories. + val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 09075eeb539aa..2a62450bcdbad 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -39,6 +39,8 @@ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { + import FsHistoryProvider._ + private var testDir: File = null before { @@ -67,7 +69,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) writeFile(newAppComplete, true, None, - SparkListenerApplicationStart("new-app-complete", None, 1L, "test", None), + SparkListenerApplicationStart(newAppComplete.getName(), Some("new-app-complete"), 1L, "test", + None), SparkListenerApplicationEnd(5L) ) @@ -75,35 +78,30 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val newAppCompressedComplete = newLogFile("new1compressed", None, inProgress = false, Some("lzf")) writeFile(newAppCompressedComplete, true, None, - SparkListenerApplicationStart("new-app-compressed-complete", None, 1L, "test", None), + SparkListenerApplicationStart(newAppCompressedComplete.getName(), Some("new-complete-lzf"), + 1L, "test", None), SparkListenerApplicationEnd(4L)) // Write an unfinished app, new-style. val newAppIncomplete = newLogFile("new2", None, inProgress = true) writeFile(newAppIncomplete, true, None, - SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test", None) + SparkListenerApplicationStart(newAppIncomplete.getName(), Some("new-incomplete"), 1L, "test", + None) ) // Write an old-style application log. - val oldAppComplete = new File(testDir, "old1") - oldAppComplete.mkdir() - createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("old-app-complete", None, 2L, "test", None), + val oldAppComplete = writeOldLog("old1", "1.0", None, true, + SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) // Check for logs so that we force the older unfinished app to be loaded, to make // sure unfinished apps are also sorted correctly. provider.checkForLogs() // Write an unfinished app, old-style. - val oldAppIncomplete = new File(testDir, "old2") - oldAppIncomplete.mkdir() - createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test", None) + val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, + SparkListenerApplicationStart("old2", None, 2L, "test", None) ) // Force a reload of data from the log directory, and check that both logs are loaded. @@ -124,16 +122,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) } - list(0) should be (makeAppInfo(newAppComplete.getName(), "new-app-complete", 1L, 5L, + list(0) should be (makeAppInfo("new-app-complete", newAppComplete.getName(), 1L, 5L, newAppComplete.lastModified(), "test", true)) - list(1) should be (makeAppInfo(newAppCompressedComplete.getName(), - "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test", - true)) - list(2) should be (makeAppInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L, + list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), + 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, -1L, - oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, -1L, + list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, + -1L, oldAppIncomplete.lastModified(), "test", false)) + list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. @@ -155,12 +152,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null val logDir = new File(testDir, codecName) logDir.mkdir() - createEmptyFile(new File(logDir, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, provider.LOG_PREFIX + "1"), false, Option(codec), + createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), SparkListenerApplicationStart("app2", None, 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(logDir, provider.COMPRESSION_CODEC_PREFIX + codecName)) + createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) val logPath = new Path(logDir.getAbsolutePath()) try { @@ -180,12 +177,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, - SparkListenerApplicationStart("app1-1", None, 1L, "test", None), + SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), SparkListenerApplicationEnd(2L) ) val logFile2 = newLogFile("new2", None, inProgress = false) writeFile(logFile2, true, None, - SparkListenerApplicationStart("app1-2", None, 1L, "test", None), + SparkListenerApplicationStart("app1-2", Some("app1-2"), 1L, "test", None), SparkListenerApplicationEnd(2L) ) logFile2.setReadable(false, false) @@ -218,6 +215,18 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("Parse logs that application is not started") { + val provider = new FsHistoryProvider((createTestConf())) + + val logFile1 = newLogFile("app1", None, inProgress = true) + writeFile(logFile1, true, None, + SparkListenerLogStart("1.4") + ) + updateAndCheck(provider) { list => + list.size should be (0) + } + } + test("SPARK-5582: empty log directory") { val provider = new FsHistoryProvider(createTestConf()) @@ -373,6 +382,33 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("SPARK-8372: new logs with no app ID are ignored") { + val provider = new FsHistoryProvider(createTestConf()) + + // Write a new log file without an app id, to make sure it's ignored. + val logFile1 = newLogFile("app1", None, inProgress = true) + writeFile(logFile1, true, None, + SparkListenerLogStart("1.4") + ) + + // Write a 1.2 log file with no start event (= no app id), it should be ignored. + writeOldLog("v12Log", "1.2", None, false) + + // Write 1.0 and 1.1 logs, which don't have app ids. + writeOldLog("v11Log", "1.1", None, true, + SparkListenerApplicationStart("v11Log", None, 2L, "test", None), + SparkListenerApplicationEnd(3L)) + writeOldLog("v10Log", "1.0", None, true, + SparkListenerApplicationStart("v10Log", None, 2L, "test", None), + SparkListenerApplicationEnd(4L)) + + updateAndCheck(provider) { list => + list.size should be (2) + list(0).id should be ("v10Log") + list(1).id should be ("v11Log") + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -412,4 +448,23 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } + private def writeOldLog( + fname: String, + sparkVersion: String, + codec: Option[CompressionCodec], + completed: Boolean, + events: SparkListenerEvent*): File = { + val log = new File(testDir, fname) + log.mkdir() + + val oldEventLog = new File(log, LOG_PREFIX + "1") + createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) + writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) + if (completed) { + createEmptyFile(new File(log, APPLICATION_COMPLETE)) + } + + log + } + } From 3ba23ffd377d12383d923d1550ac8e2b916090fc Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 30 Jun 2015 14:02:50 -0700 Subject: [PATCH 058/274] [SPARK-8736] [ML] GBTRegressor should not threshold prediction Changed GBTRegressor so it does NOT threshold the prediction. Added test which fails with bug but works after fix. CC: feynmanliang mengxr Author: Joseph K. Bradley Closes #7134 from jkbradley/gbrt-fix and squashes the following commits: 613b90e [Joseph K. Bradley] Changed GBTRegressor so it does NOT threshold the prediction --- .../spark/ml/regression/GBTRegressor.scala | 3 +-- .../ml/regression/GBTRegressorSuite.scala | 23 ++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 036e3acb07412..47c110d027d67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -172,8 +172,7 @@ final class GBTRegressionModel( // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) - if (prediction > 0.0) 1.0 else 0.0 + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } override def copy(extra: ParamMap): GBTRegressionModel = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 98fb3d3f5f22c..9682edcd9ba84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** @@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("GBTRegressor behaves reasonably on toy data") { + val df = sqlContext.createDataFrame(Seq( + LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), + LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), + LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), + LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), + LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), + LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) + )) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(df) + val preds = model.transform(df) + val predictions = preds.select("prediction").map(_.getDouble(0)) + // Checks based on SPARK-8736 (to ensure it is not doing classification) + assert(predictions.max() > 2) + assert(predictions.min() < -1) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { From 8c898964f095fcb5bb1c9212e1e484b1eb55c296 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 14:06:50 -0700 Subject: [PATCH 059/274] [SPARK-8705] [WEBUI] Don't display rects when totalExecutionTime is 0 Because `System.currentTimeMillis()` is not accurate for tasks that only need several milliseconds, sometimes `totalExecutionTime` in `makeTimeline` will be 0. If `totalExecutionTime` is 0, there will the following error in the console. ![screen shot 2015-06-29 at 7 08 55 pm](https://cloud.githubusercontent.com/assets/1000778/8406776/5cd38e04-1e92-11e5-89f2-0c5134fe4b6b.png) This PR fixes it by using an empty svg tag when `totalExecutionTime` is 0. This is a screenshot for a task that its totalExecutionTime is 0 after fixing it. ![screen shot 2015-06-30 at 12 26 52 am](https://cloud.githubusercontent.com/assets/1000778/8412896/7b33b4be-1ebf-11e5-9100-d6d656af3747.png) Author: zsxwing Closes #7088 from zsxwing/SPARK-8705 and squashes the following commits: 9ee4ef5 [zsxwing] Address comments ef2ecfa [zsxwing] Don't display rects when totalExecutionTime is 0 --- .../org/apache/spark/ui/jobs/StagePage.scala | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index e96bf49d0dd14..17e7519ddd01c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -570,6 +570,35 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val index = taskInfo.index val attempt = taskInfo.attempt + + val svgTag = + if (totalExecutionTime == 0) { + // SPARK-8705: Avoid invalid attribute error in JavaScript if execution time is 0 + """""" + } else { + s""" + | + | + | + | + | + | + |""".stripMargin + } val timelineObject = s""" |{ @@ -595,28 +624,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { |
Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)} |
Result Serialization Time: ${UIUtils.formatDuration(serializationTime)} |
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}"> - | - | - | - | - | - | - | - |', + |$svgTag', |'start': new Date($launchTime), |'end': new Date($finishTime) |} From e72526227fdcf93b7a33375ef954746ac08753f5 Mon Sep 17 00:00:00 2001 From: lee19 Date: Tue, 30 Jun 2015 14:08:00 -0700 Subject: [PATCH 060/274] [SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k I'm sorry that I made https://github.com/apache/spark/pull/6949 closed by mistake. I pushed codes again. And, I added a test code. > There is a bug that `U.numCols() = self.nCols` in `IndexedRowMatrix.computeSVD()` It should have been `U.numCols() = k = svd.U.numCols()` > ``` self = U * sigma * V.transpose (m x n) = (m x n) * (k x k) * (k x n) //ASIS --> (m x n) = (m x k) * (k x k) * (k x n) //TOBE ``` Author: lee19 Closes #6953 from lee19/MLlibBugfix and squashes the following commits: c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden. 4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error. c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib] 8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k --- .../mllib/linalg/distributed/IndexedRowMatrix.scala | 2 +- .../linalg/distributed/IndexedRowMatrixSuite.scala | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 3be530fa07537..1c33b43ea7a8a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -146,7 +146,7 @@ class IndexedRowMatrix( val indexedRows = indices.zip(svd.U.rows).map { case (i, v) => IndexedRow(i, v) } - new IndexedRowMatrix(indexedRows, nRows, nCols) + new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt) } else { null } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 4a7b99a976f0a..0ecb7a221a503 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate matrix sizes of svd") { + val k = 2 + val A = new IndexedRowMatrix(indexedRows) + val svd = A.computeSVD(k, computeU = true) + assert(svd.U.numRows() === m) + assert(svd.U.numCols() === k) + assert(svd.s.size === k) + assert(svd.V.numRows === n) + assert(svd.V.numCols === k) + } + test("validate k in svd") { val A = new IndexedRowMatrix(indexedRows) intercept[IllegalArgumentException] { From d2495f7cc7d7caaa50d122d2969ddb693e6ecebd Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 30 Jun 2015 14:09:29 -0700 Subject: [PATCH 061/274] [SPARK-8739] [WEB UI] [WINDOWS] A illegal character `\r` can be contained in StagePage. This issue was reported by saurfang. Thanks! There is a following code in StagePage.scala. ``` |width="$serializationTimeProportion%"> |', |'start': new Date($launchTime), |'end': new Date($finishTime) |} |""".stripMargin.replaceAll("\n", " ") ``` The last `replaceAll("\n", "")` doesn't work when we checkout and build source code on Windows and deploy on Linux. It's because when we checkout the source code on Windows, new-line-code is replaced with `"\r\n"` and `replaceAll("\n", "")` replaces only `"\n"`. Author: Kousuke Saruta Closes #7133 from sarutak/SPARK-8739 and squashes the following commits: 17fb044 [Kousuke Saruta] Fixed a new-line-code issue --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 17e7519ddd01c..60e3c6343122c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -628,7 +628,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { |'start': new Date($launchTime), |'end': new Date($finishTime) |} - |""".stripMargin.replaceAll("\n", " ") + |""".stripMargin.replaceAll("""[\r\n]+""", " ") timelineObject }.mkString("[", ",", "]") From 58ee2a2e47948a895e557fbcabbeadb31f0a1022 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Jun 2015 16:17:46 -0700 Subject: [PATCH 062/274] [SPARK-8738] [SQL] [PYSPARK] capture SQL AnalysisException in Python API Capture the AnalysisException in SQL, hide the long java stack trace, only show the error message. cc rxin Author: Davies Liu Closes #7135 from davies/ananylis and squashes the following commits: dad7ae7 [Davies Liu] add comment ec0c0e8 [Davies Liu] Update utils.py cdd7edd [Davies Liu] add doc 7b044c2 [Davies Liu] fix python 3 f84d3bd [Davies Liu] capture SQL AnalysisException in Python API --- python/pyspark/rdd.py | 3 +- python/pyspark/sql/context.py | 2 ++ python/pyspark/sql/tests.py | 7 +++++ python/pyspark/sql/utils.py | 54 +++++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 python/pyspark/sql/utils.py diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cb20bc8b54027..79dafb0a4ef27 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -126,11 +126,12 @@ def _load_from_socket(port, serializer): # On most of IPv6-ready systems, IPv6 will take precedence. for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) try: - sock = socket.socket(af, socktype, proto) sock.settimeout(3) sock.connect(sa) except socket.error: + sock.close() sock = None continue break diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4dda3b430cfbf..4bf232111c496 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -33,6 +33,7 @@ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.utils import install_exception_handler try: import pandas @@ -96,6 +97,7 @@ def __init__(self, sparkContext, sqlContext=None): self._jvm = self._sc._jvm self._scala_SQLContext = sqlContext _monkey_patch_RDD(self) + install_exception_handler() @property def _ssql_ctx(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 34f397d0ffef0..5af2ce09bc122 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -46,6 +46,7 @@ from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction from pyspark.sql.window import Window +from pyspark.sql.utils import AnalysisException class UTC(datetime.tzinfo): @@ -847,6 +848,12 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + def test_capture_analysis_exception(self): + self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) + self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) + # RuntimeException should not be captured + self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py new file mode 100644 index 0000000000000..8096802e7302f --- /dev/null +++ b/python/pyspark/sql/utils.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import py4j + + +class AnalysisException(Exception): + """ + Failed to analyze a SQL query plan. + """ + + +def capture_sql_exception(f): + def deco(*a, **kw): + try: + return f(*a, **kw) + except py4j.protocol.Py4JJavaError as e: + cls, msg = e.java_exception.toString().split(': ', 1) + if cls == 'org.apache.spark.sql.AnalysisException': + raise AnalysisException(msg) + raise + return deco + + +def install_exception_handler(): + """ + Hook an exception handler into Py4j, which could capture some SQL exceptions in Java. + + When calling Java API, it will call `get_return_value` to parse the returned object. + If any exception happened in JVM, the result will be Java exception object, it raise + py4j.protocol.Py4JJavaError. We replace the original `get_return_value` with one that + could capture the Java exception and throw a Python one (with the same error message). + + It's idempotent, could be called multiple times. + """ + original = py4j.protocol.get_return_value + # The original `get_return_value` is not patched, it's idempotent. + patched = capture_sql_exception(original) + # only patch the one used in in py4j.java_gateway (call Java API) + py4j.java_gateway.get_return_value = patched From 8d23587f1d285e93983b4b7d1decea01c2fe2e9e Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 30 Jun 2015 16:28:25 -0700 Subject: [PATCH 063/274] [SPARK-7739] [MLLIB] Improve ChiSqSelector example code in user guide Author: sethah Closes #7029 from sethah/working_on_SPARK-7739 and squashes the following commits: ef96916 [sethah] Fixing some style issues efea1f8 [sethah] adding clarification to ChiSqSelector example --- docs/mllib-feature-extraction.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 83e937635a55b..a69e41e2a1936 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -384,7 +384,7 @@ data2 = labels.zip(normalizer2.transform(features)) [Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. ### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. #### Model Fitting @@ -405,7 +405,7 @@ Note that the user can also construct a `ChiSqSelectorModel` by hand by providin #### Example -The following example shows the basic use of ChiSqSelector. +The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
@@ -419,10 +419,11 @@ import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category val discretizedData = data.map { lp => - LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) } -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features val selector = new ChiSqSelector(50) // Create ChiSqSelector model (selecting features) val transformer = selector.fit(discretizedData) @@ -451,19 +452,20 @@ JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category JavaRDD discretizedData = points.map( new Function() { @Override public LabeledPoint call(LabeledPoint lp) { final double[] discretizedFeatures = new double[lp.features().size()]; for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = lp.features().apply(i) / 16; + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); } return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); } }); -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features ChiSqSelector selector = new ChiSqSelector(50); // Create ChiSqSelector model (selecting features) final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); From 8133125ca0b83985e0c2aa2a6ad477556867e412 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Jun 2015 16:54:51 -0700 Subject: [PATCH 064/274] [SPARK-8741] [SQL] Remove e and pi from DataFrame functions. Author: Reynold Xin Closes #7137 from rxin/SPARK-8741 and squashes the following commits: 32c7e75 [Reynold Xin] [SPARK-8741][SQL] Remove e and pi from DataFrame functions. --- .../scala/org/apache/spark/sql/functions.scala | 18 ------------------ .../spark/sql/DataFrameFunctionsSuite.scala | 8 -------- 2 files changed, 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6331fe61052ab..5767668dd339b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -989,15 +989,6 @@ object functions { */ def cosh(columnName: String): Column = cosh(Column(columnName)) - /** - * Returns the double value that is closer than any other to e, the base of the natural - * logarithms. - * - * @group math_funcs - * @since 1.5.0 - */ - def e(): Column = EulerNumber() - /** * Computes the exponential of the given value. * @@ -1191,15 +1182,6 @@ object functions { */ def log1p(columnName: String): Column = log1p(Column(columnName)) - /** - * Returns the double value that is closer than any other to pi, the ratio of the circumference - * of a circle to its diameter. - * - * @group math_funcs - * @since 1.5.0 - */ - def pi(): Column = Pi() - /** * Computes the logarithm of the given column in base 2. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 11a8767ead96c..7ae89bcb1b9cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -86,14 +86,6 @@ class DataFrameFunctionsSuite extends QueryTest { } test("constant functions") { - checkAnswer( - testData2.select(e()).limit(1), - Row(scala.math.E) - ) - checkAnswer( - testData2.select(pi()).limit(1), - Row(scala.math.Pi) - ) checkAnswer( ctx.sql("SELECT E()"), Row(scala.math.E) From ccdb05222a223187199183fd48e3a3313d536965 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Tue, 30 Jun 2015 16:59:44 -0700 Subject: [PATCH 065/274] [SPARK-8727] [SQL] Missing python api; md5, log2 Jira: https://issues.apache.org/jira/browse/SPARK-8727 Author: Tarek Auel Author: Tarek Auel Closes #7114 from tarekauel/missing-python and squashes the following commits: ef4c61b [Tarek Auel] [SPARK-8727] revert dataframe change 4029d4d [Tarek Auel] removed dataframe pi and e unit test 66f0d2b [Tarek Auel] removed pi and e from python api and dataframe api; added _to_java_column(col) for strlen 4d07318 [Tarek Auel] fixed python unit test 45f2bee [Tarek Auel] fixed result of pi and e c39f47b [Tarek Auel] add python api bd50a3a [Tarek Auel] add missing python functions --- python/pyspark/sql/functions.py | 65 ++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 45ecd826bd3bd..4e2be88e9e3b9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -39,12 +39,15 @@ 'coalesce', 'countDistinct', 'explode', + 'log2', + 'md5', 'monotonicallyIncreasingId', 'rand', 'randn', 'sha1', 'sha2', 'sparkPartitionId', + 'strlen', 'struct', 'udf', 'when'] @@ -320,6 +323,19 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def md5(col): + """Calculates the MD5 digest and returns the value as a 32 character hex string. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.md5(_to_java_column(col)) + return Column(jc) + + @since(1.4) def monotonicallyIncreasingId(): """A column that generates monotonically increasing 64-bit integers. @@ -365,6 +381,19 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def sha2(col, numBits): @@ -383,19 +412,6 @@ def sha2(col, numBits): return Column(jc) -@ignore_unicode_prefix -@since(1.5) -def sha1(col): - """Returns the hex string result of SHA-1. - - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() - [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.sha1(_to_java_column(col)) - return Column(jc) - - @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -409,6 +425,18 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +@ignore_unicode_prefix +@since(1.5) +def strlen(col): + """Calculates the length of a string expression. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + [Row(length=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.strlen(_to_java_column(col))) + + @ignore_unicode_prefix @since(1.4) def struct(*cols): @@ -471,6 +499,17 @@ def log(arg1, arg2=None): return Column(jc) +@since(1.5) +def log2(col): + """Returns the base-2 logarithm of the argument. + + >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() + [Row(log2=2.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.log2(_to_java_column(col))) + + @since(1.4) def lag(col, count=1, default=None): """ From 3bee0f1466ddd69f26e95297b5e0d2398b6c6268 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 17:39:55 -0700 Subject: [PATCH 066/274] [SPARK-6602][Core] Update Master, Worker, Client, AppClient and related classes to use RpcEndpoint This PR updates the rest Actors in core to RpcEndpoint. Because there is no `ActorSelection` in RpcEnv, I changes the logic of `registerWithMaster` in Worker and AppClient to avoid blocking the message loop. These changes need to be reviewed carefully. Author: zsxwing Closes #5392 from zsxwing/rpc-rewrite-part3 and squashes the following commits: 2de7bed [zsxwing] Merge branch 'master' into rpc-rewrite-part3 f12d943 [zsxwing] Address comments 9137b82 [zsxwing] Fix the code style e734c71 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 2d24fb5 [zsxwing] Fix the code style 5a82374 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 fa47110 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 72304f0 [zsxwing] Update the error strategy for AkkaRpcEnv e56cb16 [zsxwing] Always send failure back to the sender a7b86e6 [zsxwing] Use JFuture for java.util.concurrent.Future aa34b9b [zsxwing] Fix the code style bd541e7 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 25a84d8 [zsxwing] Use ThreadUtils 060ff31 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 dbfc916 [zsxwing] Improve the docs and comments 837927e [zsxwing] Merge branch 'master' into rpc-rewrite-part3 5c27f97 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 fadbb9e [zsxwing] Fix the code style 6637e3c [zsxwing] Merge remote-tracking branch 'origin/master' into rpc-rewrite-part3 7fdee0e [zsxwing] Fix the return type to ExecutorService and ScheduledExecutorService e8ad0a5 [zsxwing] Fix the code style 6b2a104 [zsxwing] Log error and use SparkExitCode.UNCAUGHT_EXCEPTION exit code fbf3194 [zsxwing] Add Utils.newDaemonSingleThreadExecutor and newDaemonSingleThreadScheduledExecutor b776817 [zsxwing] Update Master, Worker, Client, AppClient and related classes to use RpcEndpoint --- .../org/apache/spark/deploy/Client.scala | 156 ++++--- .../apache/spark/deploy/DeployMessage.scala | 22 +- .../spark/deploy/LocalSparkCluster.scala | 26 +- .../spark/deploy/client/AppClient.scala | 199 +++++---- .../spark/deploy/client/TestClient.scala | 10 +- .../spark/deploy/master/ApplicationInfo.scala | 5 +- .../apache/spark/deploy/master/Master.scala | 392 +++++++++--------- .../spark/deploy/master/MasterMessages.scala | 2 +- .../spark/deploy/master/WorkerInfo.scala | 6 +- .../master/ZooKeeperLeaderElectionAgent.scala | 3 - .../deploy/master/ui/ApplicationPage.scala | 9 +- .../spark/deploy/master/ui/MasterPage.scala | 14 +- .../spark/deploy/master/ui/MasterWebUI.scala | 4 +- .../deploy/rest/StandaloneRestServer.scala | 35 +- .../spark/deploy/worker/DriverRunner.scala | 6 +- .../spark/deploy/worker/ExecutorRunner.scala | 8 +- .../apache/spark/deploy/worker/Worker.scala | 318 +++++++++----- .../spark/deploy/worker/WorkerWatcher.scala | 1 - .../spark/deploy/worker/ui/WorkerPage.scala | 11 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 2 + .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 8 +- .../cluster/SparkDeploySchedulerBackend.scala | 2 +- .../spark/deploy/master/MasterSuite.scala | 56 +-- .../rest/StandaloneRestSubmitSuite.scala | 54 +-- .../deploy/worker/WorkerWatcherSuite.scala | 15 +- .../apache/spark/rpc/RpcAddressSuite.scala | 55 +++ .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 20 +- 27 files changed, 806 insertions(+), 633 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 848b62f9de71b..71f7e2129116f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -18,17 +18,17 @@ package org.apache.spark.deploy import scala.collection.mutable.HashSet -import scala.concurrent._ +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import scala.util.{Failure, Success} -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} /** * Proxy that relays messages to the driver. @@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} * We currently don't support retry if submission fails. In HA mode, client will submit request to * all masters and see which one could handle it. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - private val masterActors = driverArgs.masters.map { m => - context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) - } - private val lostMasters = new HashSet[Address] - private var activeMasterActor: ActorSelection = null - - val timeout = RpcUtils.askTimeout(conf) - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - +private class ClientEndpoint( + override val rpcEnv: RpcEnv, + driverArgs: ClientArguments, + masterEndpoints: Seq[RpcEndpointRef], + conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging { + + // A scheduled executor used to send messages at the specified time. + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message") + // Used to provide the implicit parameter of `Future` methods. + private val forwardMessageExecutionContext = + ExecutionContext.fromExecutor(forwardMessageThread, + t => t match { + case ie: InterruptedException => // Exit normally + case e: Throwable => + logError(e.getMessage, e) + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + }) + + private val lostMasters = new HashSet[RpcAddress] + private var activeMasterEndpoint: RpcEndpointRef = null + + override def onStart(): Unit = { driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -82,29 +92,37 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.cores, driverArgs.supervise, command) - - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestSubmitDriver(driverDescription) - } + ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestKillDriver(driverId) - } + ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + } + } + + /** + * Send the message to master and forward the reply to self asynchronously. + */ + private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + for (masterEndpoint <- masterEndpoints) { + masterEndpoint.ask[T](message).onComplete { + case Success(v) => self.send(v) + case Failure(e) => + logWarning(s"Error sending messages to master $masterEndpoint", e) + }(forwardMessageExecutionContext) } } /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { + // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread + // is fine. println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + val statusResponse = + activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => println(s"ERROR: Cluster master did not recognize $driverId") @@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { - case SubmitDriverResponse(success, driverId, message) => + case SubmitDriverResponse(master, success, driverId, message) => println(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId.get) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } - case KillDriverResponse(driverId, success, message) => + case KillDriverResponse(master, driverId, success, message) => println(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } + } - case DisassociatedEvent(_, remoteAddress, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") - lostMasters += remoteAddress - // Note that this heuristic does not account for the fact that a Master can recover within - // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This - // is not currently a concern, however, because this client does not retry submissions. - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. + if (lostMasters.size >= masterEndpoints.size) { + println("No master is available, exiting.") + System.exit(-1) } + } + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") - lostMasters += remoteAddress - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master ($remoteAddress).") + println(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterEndpoints.size) { + println("No master is available, exiting.") + System.exit(-1) } + } + } + + override def onError(cause: Throwable): Unit = { + println(s"Error processing messages, exiting.") + cause.printStackTrace() + System.exit(-1) + } + + override def onStop(): Unit = { + forwardMessageThread.shutdownNow() } } @@ -194,15 +224,13 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val rpcEnv = + RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - for (m <- driverArgs.masters) { - Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) - } - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). + map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 9db6fd1ac4dbe..12727de9b4cf3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable -/** Contains messages sent between Scheduler actor nodes. */ +/** Contains messages sent between Scheduler endpoint nodes. */ private[deploy] object DeployMessages { // Worker to Master @@ -37,6 +38,7 @@ private[deploy] object DeployMessages { id: String, host: String, port: Int, + worker: RpcEndpointRef, cores: Int, memory: Int, webUiPort: Int, @@ -63,11 +65,11 @@ private[deploy] object DeployMessages { case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], driverIds: Seq[String]) - case class Heartbeat(workerId: String) extends DeployMessage + case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage @@ -92,13 +94,13 @@ private[deploy] object DeployMessages { // Worker internal - case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders case object ReregisterWithMaster // used when a worker attempts to reconnect to a master // AppClient to Master - case class RegisterApplication(appDescription: ApplicationDescription) + case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) extends DeployMessage case class UnregisterApplication(appId: String) @@ -107,7 +109,7 @@ private[deploy] object DeployMessages { // Master to AppClient - case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage + case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -123,12 +125,14 @@ private[deploy] object DeployMessages { case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage - case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + case class SubmitDriverResponse( + master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String) extends DeployMessage case class RequestKillDriver(driverId: String) extends DeployMessage - case class KillDriverResponse(driverId: String, success: Boolean, message: String) + case class KillDriverResponse( + master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage case class RequestDriverStatus(driverId: String) extends DeployMessage @@ -142,7 +146,7 @@ private[deploy] object DeployMessages { // Master to Worker & AppClient - case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String) // MasterWebUI To Master diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 0550f00a172ab..53356addf6edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master @@ -41,8 +40,8 @@ class LocalSparkCluster( extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterRpcEnvs = ArrayBuffer[RpcEnv]() + private val workerRpcEnvs = ArrayBuffer[RpcEnv]() // exposed for testing var masterWebUIPort = -1 @@ -55,18 +54,17 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, webUiPort, _) = - Master.startSystemAndActor(localHostname, 0, 0, _conf) + val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) masterWebUIPort = webUiPort - masterActorSystems += masterSystem - val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort + masterRpcEnvs += rpcEnv + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum), _conf) - workerActorSystems += workerSystem + workerRpcEnvs += workerEnv } masters @@ -77,11 +75,11 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. - workerActorSystems.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.shutdown()) // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) + masterRpcEnvs.foreach(_.shutdown()) // masterActorSystems.foreach(_.awaitTermination()) - masterActorSystems.clear() - workerActorSystems.clear() + masterRpcEnvs.clear() + workerRpcEnvs.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 43c8a934c311a..79b251e7e62fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. */ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) + private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) - private val REGISTRATION_TIMEOUT = 20.seconds + private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var masterAddress: Address = null - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null private var appId: String = null - private var registered = false - private var activeMasterUrl: String = null + @volatile private var registered = false + + private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint + with Logging { + + private var master: Option[RpcEndpointRef] = None + // To avoid calling listener.disconnected() multiple times + private var alreadyDisconnected = false + @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times + @volatile private var registerMasterFutures: Array[JFuture[_]] = null + @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) - private class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + // A scheduled executor for scheduling the registration actions + private val registrationRetryThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + override def onStart(): Unit = { try { - registerWithMaster() + registerWithMaster(1) } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } - def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterApplication(appDescription) + /** + * Register with all masters asynchronously and returns an array `Future`s for cancellation. + */ + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + for (masterAddress <- masterRpcAddresses) yield { + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = try { + if (registered) { + return + } + logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + val masterRef = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterRef.send(RegisterApplication(appDescription, self)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + }) } } - def registerWithMaster() { - tryRegisterAllMasters() - import context.dispatcher - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + /** + * Register with all masters asynchronously. It will call `registerWithMaster` every + * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times. + * Once we connect to a master successfully, all scheduling work and Futures will be cancelled. + * + * nthRetry means this is the nth attempt to register with master. + */ + private def registerWithMaster(nthRetry: Int) { + registerMasterFutures = tryRegisterAllMasters() + registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { Utils.tryOrExit { - retries += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() + } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - tryRegisterAllMasters() + registerMasterFutures.foreach(_.cancel(true)) + registerWithMaster(nthRetry + 1) } } } - } + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) } - def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. - activeMasterUrl = url - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => logWarning(s"Drop $message because has not yet connected to master") + } } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = { + masterRpcAddresses.contains(remoteAddress) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredApplication(appId_, masterUrl) => + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId_, masterRef) => + // FIXME How to handle the following cases? + // 1. A master receives multiple registrations and sends back multiple + // RegisteredApplications due to an unstable network. + // 2. Receive multiple RegisteredApplication from different masters because the master is + // changing. appId = appId_ registered = true - changeMaster(masterUrl) + master = Some(masterRef) listener.connected(appId) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not + // guaranteed), `ExecutorStateChanged` may be sent to a dead master. + sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -142,24 +184,32 @@ private[spark] class AppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + master = Some(masterRef) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) + masterRef.send(MasterChangeAcknowledged(appId)) + } - case DisassociatedEvent(_, address, _) if address == masterAddress => + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopAppClient => + markDead("Application has been stopped.") + sendToMaster(UnregisterApplication(appId)) + context.reply(true) + stop() + } + + override def onDisconnected(address: RpcAddress): Unit = { + if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = { + if (isPossibleMaster(address)) { logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - master ! UnregisterApplication(appId) - sender ! true - context.stop(self) + } } /** @@ -179,28 +229,31 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop(): Unit = { + if (registrationRetryTimer != null) { + registrationRetryTimer.cancel(true) + } + registrationRetryThread.shutdownNow() + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() } } def start() { // Just launch an actor; it will call back into the listener. - actor = actorSystem.actorOf(Props(new ClientActor)) + endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) } def stop() { - if (actor != null) { + if (endpoint != null) { try { - val timeout = RpcUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + endpoint.askWithRetry[Boolean](StopAppClient) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - actor = null + endpoint = null } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 40835b9550586..1c79089303e3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -46,13 +47,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, - conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 1620e95bea218..aa54ed9360f36 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,9 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +32,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fccceb3ea528b..3e7c16722805e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,20 +21,18 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -47,23 +45,27 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[master] class Master( - host: String, - port: Int, + override val rpcEnv: RpcEnv, + address: RpcAddress, webUiPort: Int, val securityMgr: SecurityManager, val conf: SparkConf) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends ThreadSafeRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + + // TODO Remove it once we don't use akka.serialization.Serialization + private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) @@ -75,10 +77,10 @@ private[master] class Master( val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] - private val addressToWorker = new HashMap[Address, WorkerInfo] + private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] - private val actorToApp = new HashMap[ActorRef, ApplicationInfo] - private val addressToApp = new HashMap[Address, ApplicationInfo] + private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 private val appIdToUI = new HashMap[String, SparkUI] @@ -89,21 +91,22 @@ private[master] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(address.host, "Expected hostname") private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) private val masterSource = new MasterSource(this) - private val webUi = new MasterWebUI(this, webUiPort) + // After onStart, webUi will be set + private var webUi: MasterWebUI = null private val masterPublicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host + if (envVar != null) envVar else address.host } - private val masterUrl = "spark://" + host + ":" + port + private val masterUrl = address.toSparkURL private var masterWebUiUrl: String = _ private var state = RecoveryState.STANDBY @@ -112,7 +115,9 @@ private[master] class Master( private var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ + + private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -130,20 +135,23 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) + Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) } else { None } private val restServerBoundPort = restServer.map(_.start()) - override def preStart() { + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new MasterWebUI(this, webUiPort) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CheckForWorkerTimeOut) + } + }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -157,16 +165,16 @@ private[master] class Master( case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(context.system)) + .newInstance(conf, SerializationExtension(actorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -176,18 +184,17 @@ private[master] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! - logError("Master actor restarted due to exception", reason) - } - - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) } + if (checkForWorkerTimeOutTask != null) { + checkForWorkerTimeOutTask.cancel(true) + } + forwardMessageThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -197,14 +204,14 @@ private[master] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -215,8 +222,11 @@ private[master] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = forwardMessageThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CompleteRecovery) + } + }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } } @@ -227,111 +237,42 @@ private[master] class Master( System.exit(0) } - case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) => - { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - sender, workerUiPort, publicAddress) + workerRef, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } - case RequestSubmitDriver(description) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only accept driver submissions in ALIVE state." - sender ! SubmitDriverResponse(false, None, msg) - } else { - logInfo("Driver submitted " + description.command.mainClass) - val driver = createDriver(description) - persistenceEngine.addDriver(driver) - waitingDrivers += driver - drivers.add(driver) - schedule() - - // TODO: It might be good to instead have the submission client poll the master to determine - // the current status of the driver. For now it's simply "fire and forget". - - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") - } - } - - case RequestKillDriver(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - s"Can only kill drivers in ALIVE state." - sender ! KillDriverResponse(driverId, success = false, msg) - } else { - logInfo("Asked to kill driver " + driverId) - val driver = drivers.find(_.id == driverId) - driver match { - case Some(d) => - if (waitingDrivers.contains(d)) { - waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) - } else { - // We just notify the worker to kill the driver here. The final bookkeeping occurs - // on the return path when the worker submits a state change back to the master - // to notify it that the driver was successfully killed. - d.worker.foreach { w => - w.actor ! KillDriver(driverId) - } - } - // TODO: It would be nice for this to be a synchronous response - val msg = s"Kill request for $driverId submitted" - logInfo(msg) - sender ! KillDriverResponse(driverId, success = true, msg) - case None => - val msg = s"Driver $driverId has already finished or does not exist" - logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) - } - } - } - - case RequestDriverStatus(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only request driver status in ALIVE state." - sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) - } else { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) - } - } - } - - case RegisterApplication(description) => { + case RegisterApplication(description, driver) => { + // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response } else { logInfo("Registering app " + description.name) - val app = createApplication(description, sender) + val app = createApplication(description, driver) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + driver.send(RegisteredApplication(app.id, self)) schedule() } } @@ -343,7 +284,7 @@ private[master] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -384,7 +325,7 @@ private[master] class Master( } } - case Heartbeat(workerId) => { + case Heartbeat(workerId, worker) => { idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -392,7 +333,7 @@ private[master] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + worker.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -444,30 +385,103 @@ private[master] class Master( logInfo(s"Received unregister request from application $applicationId") idToApp.get(applicationId).foreach(finishApplication) - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() } + } - case RequestMasterState => { - sender ! MasterStateResponse( - host, port, restServerBoundPort, - workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." + context.reply(SubmitDriverResponse(self, false, None, msg)) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + context.reply(SubmitDriverResponse(self, true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) + } } - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. + d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + context.reply(KillDriverResponse(self, driverId, success = true, msg)) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } + } + } + + case RequestDriverStatus(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + context.reply( + DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + context.reply(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) + case None => + context.reply(DriverStatusResponse(found = false, None, None, None, None)) + } + } + } + + case RequestMasterState => { + context.reply(MasterStateResponse( + address.host, address.port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } case BoundPortsRequest => { - sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) + context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) } } + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + private def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 @@ -479,7 +493,7 @@ private[master] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -496,7 +510,7 @@ private[master] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.endpoint.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -504,6 +518,7 @@ private[master] class Master( } private def completeRecovery() { + // TODO Why synchronized // Ensure "only-once" recovery semantics using a short synchronization period. synchronized { if (state != RecoveryState.RECOVERING) { return } @@ -623,10 +638,10 @@ private[master] class Master( private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.endpoint.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { @@ -638,7 +653,7 @@ private[master] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -661,11 +676,11 @@ private[master] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.endpoint.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -687,14 +702,15 @@ private[master] class Master( schedule() } - private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): + ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -703,7 +719,7 @@ private[master] class Master( applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(app.driver) = app + endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app } @@ -717,8 +733,8 @@ private[master] class Master( logInfo("Removing app " + app.id) apps -= app idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address + endpointToApp -= app.driver + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -735,19 +751,19 @@ private[master] class Master( for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) + exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id)) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! ApplicationFinished(app.id) + w.endpoint.send(ApplicationFinished(app.id)) } } } @@ -768,7 +784,7 @@ private[master] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, None, app.desc.eventLogCodec) + eventLogDir, app.id, app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) @@ -832,14 +848,14 @@ private[master] class Master( private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray for (worker <- toRemove) { if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) + worker.id, WORKER_TIMEOUT_MS / 1000)) removeWorker(worker) } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } @@ -862,7 +878,7 @@ private[master] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.endpoint.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -891,57 +907,33 @@ private[master] class Master( } private[deploy] object Master extends Logging { - val systemName = "sparkMaster" - private val actorName = "Master" + val SYSTEM_NAME = "sparkMaster" + val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() - } - - /** - * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaUrl(sparkUrl: String, protocol: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - AkkaUtils.address(protocol, systemName, host, port, actorName) - } - - /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaAddress(sparkUrl: String, protocol: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address(protocol, systemName, host, port) + val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** - * Start the Master and return a four tuple of: - * (1) The Master actor system - * (2) The bound port - * (3) The web UI bound port - * (4) The REST server bound port, if any + * Start the Master and return a three tuple of: + * (1) The Master RpcEnv + * (2) The web UI bound port + * (3) The REST server bound port, if any */ - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { + conf: SparkConf): (RpcEnv, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf( - Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = RpcUtils.askTimeout(conf) - val portsRequest = actor.ask(BoundPortsRequest)(timeout) - val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] - (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) + val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) + val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, + new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) + val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 15c6296888f70..68c937188b333 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,7 +28,7 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to Master + // Master to itself case object CheckForWorkerTimeOut diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 9b3d48c6edc84..471811037e5e2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val endpoint: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 52758d6a7c4be..6fdff86f66e01 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,10 +17,7 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} import org.apache.spark.deploy.SparkCuratorUtil diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 06e265f99e231..e28e7e379ac91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask - import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorDesc @@ -32,14 +29,12 @@ import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 6a7c74020bace..c3e20ebf8d6eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - Await.result(stateFuture, timeout) + master.askWithRetry[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { @@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } def handleDriverKillRequest(request: HttpServletRequest): Unit = { - handleKillRequest(request, id => { master ! RequestKillDriver(id) }) + handleKillRequest(request, id => { + master.ask[KillDriverResponse](RequestKillDriver(id)) + }) } private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 2111a8581f2e4..6174fc11f83d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -23,7 +23,6 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. @@ -33,8 +32,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { - val masterActorRef = master.self - val timeout = RpcUtils.askTimeout(master.conf) + val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) val masterPage = new MasterPage(this) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 502b9bb701ccf..d5b9bcab1423f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import akka.actor.ActorRef import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** @@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to * @param masterConf the conf used by the Master - * @param masterActor reference to the Master actor to which requests can be sent + * @param masterEndpoint reference to the Master endpoint to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { protected override val submitRequestServlet = - new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = - new StandaloneKillRequestServlet(masterActor, masterConf) + new StandaloneKillRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = - new StandaloneStatusRequestServlet(masterActor, masterConf) + new StandaloneStatusRequestServlet(masterEndpoint, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( - DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion k.message = response.message @@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( - DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion @@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ private[rest] class StandaloneSubmitRequestServlet( - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String, conf: SparkConf) extends SubmitRequestServlet { @@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion submitResponse.message = response.message diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 1386055eb8c48..ec51c3d935d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -21,7 +21,6 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{Utils, Clock, SystemClock} /** @@ -43,7 +43,7 @@ private[deploy] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String, val securityManager: SecurityManager) extends Logging { @@ -107,7 +107,7 @@ private[deploy] class DriverRunner( finalState = Some(state) - worker ! DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index fff17e1095042..29a5042285578 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -21,10 +21,10 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged @@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val webUiPort: Int, @@ -91,7 +91,7 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -159,7 +159,7 @@ private[deploy] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ebc6cd76c6afd..82e9578bbcba5 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,15 +21,14 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.ExecutionContext import scala.util.Random - -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} @@ -38,32 +37,39 @@ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} -/** - * @param masterAkkaUrls Each url should be a valid akka url. - */ private[worker] class Worker( - host: String, - port: Int, + override val rpcEnv: RpcEnv, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], - actorSystemName: String, - actorName: String, + masterRpcAddresses: Array[RpcAddress], + systemName: String, + endpointName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends ThreadSafeRpcEndpoint with Logging { + + private val host = rpcEnv.address.host + private val port = rpcEnv.address.port Utils.checkHost(host, "Expected hostname") assert (port > 0) + // A scheduled executor used to send messages at the specified time. + private val forwordMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + + // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` + // methods. + private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) + // For worker and executor IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -79,32 +85,26 @@ private[worker] class Worker( val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * - REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 - * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)) + private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)) private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") - private var master: ActorSelection = null - private var masterAddress: Address = null + private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val akkaUrl = AkkaUtils.address( - AkkaUtils.protocol(context.system), - actorSystemName, - host, - port, - actorName) - @volatile private var registered = false - @volatile private var connected = false + private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private var registered = false + private var connected = false private val workerId = generateWorkerId() private val sparkHome = if (testing) { @@ -136,7 +136,18 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - private var registrationRetryTimer: Option[Cancellable] = None + private var registerMasterFutures: Array[JFuture[_]] = null + private var registrationRetryTimer: Option[JScheduledFuture[_]] = None + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) var coresUsed = 0 var memoryUsed = 0 @@ -162,14 +173,13 @@ private[worker] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -181,24 +191,32 @@ private[worker] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(url: String, uiUrl: String) { + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. - activeMasterUrl = url + activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) + master = Some(masterRef) connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } - private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + masterRpcAddresses.map { masterAddress => + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + }) } } @@ -211,8 +229,7 @@ private[worker] class Worker( Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") /** @@ -235,21 +252,48 @@ private[worker] class Worker( * still not safe if the old master recovers within this interval, but this is a much * less likely scenario. */ - if (master != null) { - master ! RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) - } else { - // We are retrying the initial registration - tryRegisterAllMasters() + master match { + case Some(masterRef) => + // registered == false && master != None means we lost the connection to master, so + // masterRef cannot be used and we need to recreate it again. Note: we must not set + // master to None due to the above comments. + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + val masterAddress = masterRef.address + registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + })) + case None => + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + // We are retrying the initial registration + registerMasterFutures = tryRegisterAllMasters() } // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = Some( + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) } } else { logError("All masters are unresponsive! Giving up.") @@ -258,41 +302,67 @@ private[worker] class Worker( } } + /** + * Cancel last registeration retry, or do nothing if no retry + */ + private def cancelLastRegistrationRetry(): Unit = { + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures = null + } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = None + } + private def registerWithMaster() { - // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // onDisconnected may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. registrationRetryTimer match { case None => registered = false - tryRegisterAllMasters() + registerMasterFutures = tryRegisterAllMasters() connectionAttemptCount = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + " attempt scheduled already.") } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterUrl, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterUrl) + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) registered = true - changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + changeMaster(masterRef, masterWebUiUrl) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(SendHeartbeat) + } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! Heartbeat(workerId) } + if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Copy ids so that it can be used in the cleanup thread. + val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { val appDirs = workDir.listFiles() if (appDirs == null) { @@ -302,27 +372,27 @@ private[worker] class Worker( // the directory is used by an application - check that the application is not running // when cleaning up val appIdFromDir = dir.getName - val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + val isAppStillRunning = appIds.contains(appIdFromDir) dir.isDirectory && !isAppStillRunning && - !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => logInfo(s"Removing directory: ${dir.getPath}") Utils.deleteRecursively(dir) } - } + }(cleanupThreadExecutor) - cleanupFuture onFailure { + cleanupFuture.onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) - } + }(cleanupThreadExecutor) - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl, masterWebUiUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + changeMaster(masterRef, masterWebUiUrl) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case RegisterWorkerFailed(message) => if (!registered) { @@ -369,14 +439,14 @@ private[worker] class Worker( publicAddress, sparkHome, executorDir, - akkaUrl, + workerUri, conf, appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -384,14 +454,14 @@ private[worker] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => + sendToMaster(executorStateChanged) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -434,7 +504,7 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl, + workerUri, securityMgr) drivers(driverId) = driver driver.start() @@ -453,7 +523,7 @@ private[worker] class Worker( } } - case DriverStateChanged(driverId, state, exception) => { + case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { state match { case DriverState.ERROR => logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") @@ -466,23 +536,13 @@ private[worker] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - master ! DriverStateChanged(driverId, state, exception) + sendToMaster(driverStageChanged) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem coresUsed -= driver.driverDesc.cores } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, drivers.values.toList, - finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) - case ReregisterWithMaster => reregisterWithMaster() @@ -491,6 +551,21 @@ private[worker] class Worker( maybeCleanupApplication(id) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestWorkerState => + context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (master.exists(_.address == remoteAddress)) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false @@ -510,13 +585,29 @@ private[worker] class Worker( } } + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => + logWarning( + s"Dropping $message because the connection to master has not yet been established") + } + } + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { + cleanupThreadExecutor.shutdownNow() metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + cancelLastRegistrationRetry() + forwordMessageScheduler.shutdownNow() + registerMasterThreadPool.shutdownNow() executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() @@ -530,12 +621,12 @@ private[deploy] object Worker extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, @@ -544,18 +635,17 @@ private[deploy] object Worker extends Logging { masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None, - conf: SparkConf = new SparkConf): (ActorSystem, Int) = { + conf: SparkConf = new SparkConf): RpcEnv = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses, + systemName, actorName, workDir, conf, securityMgr)) + rpcEnv } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 83fb991891a41..fae5640b9a213 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.worker import org.apache.spark.Logging -import org.apache.spark.deploy.DeployMessages.SendHeartbeat import org.apache.spark.rpc._ /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 9f9f27d71e1ae..fd905feb97e92 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - private val workerActor = parent.worker.self - private val timeout = parent.timeout + private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 12b6b28d4d7ec..3b6938ec639c3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -158,6 +158,8 @@ private[spark] case class RpcAddress(host: String, port: Int) { val hostPort: String = host + ":" + port override val toString: String = hostPort + + def toSparkURL: String = "spark://" + hostPort } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 0161962cde073..31ebe5ac5bca3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -180,10 +180,10 @@ private[spark] class AkkaRpcEnv private[akka] ( }) } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - _sender ! AkkaFailure(e) - } else { + _sender ! AkkaFailure(e) + if (!needReply) { + // If the sender does not require a reply, it may not handle the exception. So we rethrow + // "e" to make sure it will be processed. throw e } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ccf1dc5af6120..687ae9620460f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend( val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() waitForRegistration() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 014e87bb40254..9cb6dd43bac47 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,63 +19,21 @@ package org.apache.spark.deploy.master import java.util.Date -import scala.concurrent.Await import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps -import akka.actor.Address import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy._ class MasterSuite extends SparkFunSuite with Matchers with Eventually { - test("toAkkaUrl") { - val conf = new SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp") - assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl with SSL") { - val conf = new SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl: a typo url") { - val conf = new SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } - - test("toAkkaAddress") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp") - assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress with SSL") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress: a typo url") { - val conf = new SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } - test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) conf.set("spark.deploy.recoveryMode", "CUSTOM") @@ -129,16 +87,16 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { port = 10000, cores = 0, memory = 0, - actor = null, + endpoint = null, webUiPort = 0, publicAddress = "" ) - val (actorSystem, port, uiPort, restPort) = - Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf) + val (rpcEnv, uiPort, restPort) = + Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf) try { - Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds) + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get @@ -154,8 +112,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { workers.map(_.id) should contain(workerToPersist.id) } finally { - actorSystem.shutdown() - actorSystem.awaitTermination() + rpcEnv.shutdown() + rpcEnv.awaitTermination() } CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 197f68e7ec5ed..96e456d889ac3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -23,14 +23,14 @@ import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.google.common.base.Charsets import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} import org.apache.spark.deploy.master.DriverState._ @@ -39,11 +39,11 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. */ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { - private var actorSystem: Option[ActorSystem] = None + private var rpcEnv: Option[RpcEnv] = None private var server: Option[RestSubmissionServer] = None override def afterEach() { - actorSystem.foreach(_.shutdown()) + rpcEnv.foreach(_.shutdown()) server.foreach(_.stop()) } @@ -377,31 +377,32 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { killMessage: String = "driver is killed", state: DriverState = FINISHED, exception: Option[Exception] = None): String = { - startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + startServer(new DummyMaster(_, submitId, submitMessage, killMessage, state, exception)) } /** Start a smarter dummy server that keeps track of submitted driver states. */ private def startSmartServer(): String = { - startServer(new SmarterMaster) + startServer(new SmarterMaster(_)) } /** Start a dummy server that is faulty in many ways... */ private def startFaultyServer(): String = { - startServer(new DummyMaster, faulty = true) + startServer(new DummyMaster(_), faulty = true) } /** - * Start a [[StandaloneRestServer]] that communicates with the given actor. + * Start a [[StandaloneRestServer]] that communicates with the given endpoint. * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. * Return the master URL that corresponds to the address of this server. */ - private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + private def startServer( + makeFakeMaster: RpcEnv => RpcEndpoint, faulty: Boolean = false): String = { val name = "test-standalone-rest-protocol" val conf = new SparkConf val localhost = Utils.localHostName() val securityManager = new SecurityManager(conf) - val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) - val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _rpcEnv = RpcEnv.create(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _rpcEnv.setupEndpoint("fake-master", makeFakeMaster(_rpcEnv)) val _server = if (faulty) { new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") @@ -410,7 +411,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { } val port = _server.start() // set these to clean them up after every test - actorSystem = Some(_actorSystem) + rpcEnv = Some(_rpcEnv) server = Some(_server) s"spark://$localhost:$port" } @@ -505,20 +506,21 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { * In all responses, the success parameter is always true. */ private class DummyMaster( + override val rpcEnv: RpcEnv, submitId: String = "fake-driver-id", submitMessage: String = "submitted", killMessage: String = "killed", state: DriverState = FINISHED, exception: Option[Exception] = None) - extends Actor { + extends RpcEndpoint { - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => - sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) + context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage)) case RequestKillDriver(driverId) => - sender ! KillDriverResponse(driverId, success = true, killMessage) + context.reply(KillDriverResponse(self, driverId, success = true, killMessage)) case RequestDriverStatus(driverId) => - sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception)) } } @@ -531,28 +533,28 @@ private class DummyMaster( * Submits are always successful while kills and status requests are successful only * if the driver was submitted in the past. */ -private class SmarterMaster extends Actor { +private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { private var counter: Int = 0 private val submittedDrivers = new mutable.HashMap[String, DriverState] - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => val driverId = s"driver-$counter" submittedDrivers(driverId) = RUNNING counter += 1 - sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted") + context.reply(SubmitDriverResponse(self, success = true, Some(driverId), "submitted")) case RequestKillDriver(driverId) => val success = submittedDrivers.contains(driverId) if (success) { submittedDrivers(driverId) = KILLED } - sender ! KillDriverResponse(driverId, success, "killed") + context.reply(KillDriverResponse(self, driverId, success, "killed")) case RequestDriverStatus(driverId) => val found = submittedDrivers.contains(driverId) val state = submittedDrivers.get(driverId) - sender ! DriverStatusResponse(found, state, None, None, None) + context.reply(DriverStatusResponse(found, state, None, None, None)) } } @@ -568,7 +570,7 @@ private class FaultyStandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { @@ -578,7 +580,7 @@ private class FaultyStandaloneRestServer( /** A faulty servlet that produces malformed responses. */ class MalformedSubmitServlet - extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { + extends StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -588,7 +590,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterEndpoint, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -597,7 +599,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. */ - class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index ac18f04a11475..cd24d79423316 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.worker -import akka.actor.AddressFromURIString import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -26,13 +25,11 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) + val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected( - RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } @@ -40,13 +37,13 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" - val otherAkkaAddress = AddressFromURIString(otherAkkaURL) + val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val otherAddress = "akka://test@4.3.2.1:1234/user/OtherActor" + val otherAkkaAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + workerWatcher.onDisconnected(otherAkkaAddress) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala new file mode 100644 index 0000000000000..b3223ec61bf79 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.{SparkException, SparkFunSuite} + +class RpcAddressSuite extends SparkFunSuite { + + test("hostPort") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + assert(address.hostPort == "1.2.3.4:1234") + } + + test("fromSparkURL") { + val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234") + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + } + + test("fromSparkURL: a typo url") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("spark://1.2. 3.4:1234") + } + assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) + } + + test("fromSparkURL: invalid scheme") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("invalid://1.2.3.4:1234") + } + assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage) + } + + test("toSparkURL") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.toSparkURL == "spark://1.2.3.4:1234") + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index a33a83db7bc9e..4aa75c9230b2c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc.akka import org.apache.spark.rpc._ -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { @@ -47,4 +47,22 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { } } + test("uriOf") { + val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } + + test("uriOf: ssl") { + val conf = SSLSampleConfigs.sparkSSLConfig() + val securityManager = new SecurityManager(conf) + val rpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, securityManager)) + try { + val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } finally { + rpcEnv.shutdown() + } + } + } From f457569886e9de9256ad269cb4a3d73a8918766d Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 30 Jun 2015 20:19:43 -0700 Subject: [PATCH 067/274] [SPARK-8471] [ML] Rename DiscreteCosineTransformer to DCT Rename DiscreteCosineTransformer and related classes to DCT. Author: Feynman Liang Closes #7138 from feynmanliang/dct-features and squashes the following commits: e547b3e [Feynman Liang] Fix renaming bug 9d5c9e4 [Feynman Liang] Lowercase JavaDCTSuite variable f9a8958 [Feynman Liang] Remove old files f8fe794 [Feynman Liang] Merge branch 'master' into dct-features 894d0b2 [Feynman Liang] Rename DiscreteCosineTransformer to DCT 433dbc7 [Feynman Liang] Test refactoring 91e9636 [Feynman Liang] Style guide and test helper refactor b5ac19c [Feynman Liang] Use Vector types, add Java test 530983a [Feynman Liang] Tests for other numeric datatypes 195d7aa [Feynman Liang] Implement support for arbitrary numeric types 95d4939 [Feynman Liang] Working DCT for 1D Doubles --- .../{DiscreteCosineTransformer.scala => DCT.scala} | 4 ++-- ...creteCosineTransformerSuite.java => JavaDCTSuite.java} | 8 ++++---- ...iscreteCosineTransformerSuite.scala => DCTSuite.scala} | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/feature/{DiscreteCosineTransformer.scala => DCT.scala} (95%) rename mllib/src/test/java/org/apache/spark/ml/feature/{JavaDiscreteCosineTransformerSuite.java => JavaDCTSuite.java} (90%) rename mllib/src/test/scala/org/apache/spark/ml/feature/{DiscreteCosineTransformerSuite.scala => DCTSuite.scala} (94%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala rename to mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index a2f4d59f81c44..228347635c92b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.types.DataType * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]]. */ @Experimental -class DiscreteCosineTransformer(override val uid: String) - extends UnaryTransformer[Vector, Vector, DiscreteCosineTransformer] { +class DCT(override val uid: String) + extends UnaryTransformer[Vector, Vector, DCT] { def this() = this(Identifiable.randomUID("dct")) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java similarity index 90% rename from mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java rename to mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 28bc5f65e0532..845eed61c45c6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -37,13 +37,13 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaDiscreteCosineTransformerSuite { +public class JavaDCTSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaDiscreteCosineTransformerSuite"); + jsc = new JavaSparkContext("local", "JavaDCTSuite"); jsql = new SQLContext(jsc); } @@ -66,11 +66,11 @@ public void javaCompatibilityTest() { double[] expectedResult = input.clone(); (new DoubleDCT_1D(input.length)).forward(expectedResult, true); - DiscreteCosineTransformer DCT = new DiscreteCosineTransformer() + DCT dct = new DCT() .setInputCol("vec") .setOutputCol("resultVec"); - Row[] result = DCT.transform(dataset).select("resultVec").collect(); + Row[] result = dct.transform(dataset).select("resultVec").collect(); Vector resultVec = result[0].getAs("resultVec"); Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala similarity index 94% rename from mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index ed0fc11f78f69..37ed2367c33f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DiscreteCosineTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) @@ -58,7 +58,7 @@ class DiscreteCosineTransformerSuite extends SparkFunSuite with MLlibTestSparkCo DCTTestData(data, expectedResult) )) - val transformer = new DiscreteCosineTransformer() + val transformer = new DCT() .setInputCol("vec") .setOutputCol("resultVec") .setInverse(inverse) From b6e76edf3005c078b407f63b0a05d3a28c18c742 Mon Sep 17 00:00:00 2001 From: x1- Date: Tue, 30 Jun 2015 20:35:46 -0700 Subject: [PATCH 068/274] [SPARK-8535] [PYSPARK] PySpark : Can't create DataFrame from Pandas dataframe with no explicit column name Because implicit name of `pandas.columns` are Int, but `StructField` json expect `String`. So I think `pandas.columns` are should be convert to `String`. ### issue * [SPARK-8535 PySpark : Can't create DataFrame from Pandas dataframe with no explicit column name](https://issues.apache.org/jira/browse/SPARK-8535) Author: x1- Closes #7124 from x1-/SPARK-8535 and squashes the following commits: d68fd38 [x1-] modify unit-test using pandas. ea1897d [x1-] For implicit name of pandas.columns are Int, so should be convert to String. --- python/pyspark/sql/context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4bf232111c496..309c11faf9319 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -344,13 +344,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] + >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]]).collect()) # doctest: +SKIP + [Row(0=1, 1=2)] """ if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") if has_pandas and isinstance(data, pandas.DataFrame): if schema is None: - schema = list(data.columns) + schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): From 64c14618d3f4ede042bd3f6a542bc17a730afb0e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 21:57:07 -0700 Subject: [PATCH 069/274] [SPARK-6602][Core]Remove unnecessary synchronized A follow-up pr to address https://github.com/apache/spark/pull/5392#discussion_r33627528 Author: zsxwing Closes #7141 from zsxwing/pr5392-follow-up and squashes the following commits: fcf7b50 [zsxwing] Remove unnecessary synchronized --- .../main/scala/org/apache/spark/deploy/master/Master.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 3e7c16722805e..48070768f6edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -518,12 +518,9 @@ private[master] class Master( } private def completeRecovery() { - // TODO Why synchronized // Ensure "only-once" recovery semantics using a short synchronization period. - synchronized { - if (state != RecoveryState.RECOVERING) { return } - state = RecoveryState.COMPLETING_RECOVERY - } + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) From 365c14055e90db5ea4b25afec03022be81c8a704 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Jun 2015 23:04:54 -0700 Subject: [PATCH 070/274] [SPARK-8748][SQL] Move castability test out from Cast case class into Cast object. This patch moved resolve function in Cast case class into the companion object, and renamed it canCast. We can then use this in the analyzer without a Cast expr. Author: Reynold Xin Closes #7145 from rxin/cast and squashes the following commits: cd086a9 [Reynold Xin] Whitespace changes. 4d2d989 [Reynold Xin] [SPARK-8748][SQL] Move castability test out from Cast case class into Cast object. --- .../spark/sql/catalyst/expressions/Cast.scala | 144 ++++++++++-------- 1 file changed, 78 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d69d490ad666a..2d99d1a3fe8dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -27,23 +27,65 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -/** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override def checkInputDataTypes(): TypeCheckResult = { - if (resolve(child.dataType, dataType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType} to $dataType") - } - } +object Cast { - override def foldable: Boolean = child.foldable + /** + * Returns true iff we can cast `from` type to `to` type. + */ + def canCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (fromType, toType) if fromType == toType => true + + case (NullType, _) => true + + case (_, StringType) => true - override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable + case (StringType, BinaryType) => true - private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true + + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true + + case (_, DateType) => true + + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true + case (_: NumericType, _: NumericType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + canCast(fromType, toType) && + resolvableNullability(fn || forceNullable(fromType, toType), tn) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + canCast(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && + canCast(fromValue, toValue) && + resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (fromField, toField) => + canCast(fromField.dataType, toField.dataType) && + resolvableNullability( + fromField.nullable || forceNullable(fromField.dataType, toField.dataType), + toField.nullable) + } + + case _ => false + } + + private def resolvableNullability(from: Boolean, to: Boolean) = !from || to + + private def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case (DoubleType, TimestampType) => true @@ -58,61 +100,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null case _ => false } +} - private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to - - private[this] def resolve(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (from, to) if from == to => true - - case (NullType, _) => true - - case (_, StringType) => true - - case (StringType, BinaryType) => true - - case (StringType, BooleanType) => true - case (DateType, BooleanType) => true - case (TimestampType, BooleanType) => true - case (_: NumericType, BooleanType) => true - - case (StringType, TimestampType) => true - case (BooleanType, TimestampType) => true - case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => true - - case (_, DateType) => true - - case (StringType, _: NumericType) => true - case (BooleanType, _: NumericType) => true - case (DateType, _: NumericType) => true - case (TimestampType, _: NumericType) => true - case (_: NumericType, _: NumericType) => true - - case (ArrayType(from, fn), ArrayType(to, tn)) => - resolve(from, to) && - resolvableNullability(fn || forceNullable(from, to), tn) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - resolve(fromKey, toKey) && - (!forceNullable(fromKey, toKey)) && - resolve(fromValue, toValue) && - resolvableNullability(fn || forceNullable(fromValue, toValue), tn) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.size == toFields.size && - fromFields.zip(toFields).forall { - case (fromField, toField) => - resolve(fromField.dataType, toField.dataType) && - resolvableNullability( - fromField.nullable || forceNullable(fromField.dataType, toField.dataType), - toField.nullable) - } +/** Cast the child expression to the target data type. */ +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - case _ => false + override def checkInputDataTypes(): TypeCheckResult = { + if (Cast.canCast(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") } } + override def foldable: Boolean = child.foldable + + override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable + override def toString: String = s"CAST($child, $dataType)" // [[func]] assumes the input is no longer null because eval already does the null check. @@ -172,7 +177,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => (if (b) 1L else 0)) + buildCast[Boolean](_, b => if (b) 1L else 0) case LongType => buildCast[Long](_, l => longToTimestamp(l)) case IntegerType => @@ -388,7 +393,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericMutableRow(from.fields.size) + val newRow = new GenericMutableRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 while (i < row.length) { @@ -427,20 +432,23 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - // TODO(cg): Add support for more data types. + // TODO: Add support for more data types. (child.dataType, dataType) match { case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => s"${ctx.stringType}.fromBytes($c)") + case (DateType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + case (TimestampType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + case (_, StringType) => defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") @@ -450,12 +458,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + case (dt: DecimalType, BooleanType) => defineCodeGen(ctx, ev, c => s"!$c.isZero()") + case (dt: NumericType, BooleanType) => defineCodeGen(ctx, ev, c => s"$c != 0") + case (_: DecimalType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + case (_: NumericType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") From fc3a6fe67f5aeda2443958c31f097daeba8549e5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 00:08:16 -0700 Subject: [PATCH 071/274] [SPARK-8749][SQL] Remove HiveTypeCoercion trait. Moved all the rules into the companion object. Author: Reynold Xin Closes #7147 from rxin/SPARK-8749 and squashes the following commits: c1c6dc0 [Reynold Xin] [SPARK-8749][SQL] Remove HiveTypeCoercion trait. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 59 ++++++++----------- .../analysis/HiveTypeCoercionSuite.scala | 14 ++--- 3 files changed, 33 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 117c87a785fdb..15e84e68b9881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -43,7 +43,7 @@ class Analyzer( registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { + extends RuleExecutor[LogicalPlan] with CheckAnalysis { def resolver: Resolver = { if (conf.caseSensitiveAnalysis) { @@ -76,7 +76,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - typeCoercionRules ++ + HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e525ad623ff12..a9d396d1faeeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -22,7 +22,32 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ + +/** + * A collection of [[Rule Rules]] that can be used to coerce differing types that + * participate in operations into compatible ones. Most of these rules are based on Hive semantics, + * but they do not introduce any dependencies on the hive codebase. For this reason they remain in + * Catalyst until we have a more standard set of coercions. + */ object HiveTypeCoercion { + + val typeCoercionRules = + PropagateTypes :: + ConvertNaNs :: + InConversion :: + WidenTypes :: + PromoteStrings :: + DecimalPrecision :: + BooleanEquality :: + StringToIntegralCasts :: + FunctionArgumentConversion :: + CaseWhenCoercion :: + IfCoercion :: + Division :: + PropagateTypes :: + AddCastForAutoCastInputTypes :: + Nil + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: private val numericPrecedence = @@ -79,7 +104,6 @@ object HiveTypeCoercion { }) } - /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -90,34 +114,6 @@ object HiveTypeCoercion { case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } -} - -/** - * A collection of [[Rule Rules]] that can be used to coerce differing types that - * participate in operations into compatible ones. Most of these rules are based on Hive semantics, - * but they do not introduce any dependencies on the hive codebase. For this reason they remain in - * Catalyst until we have a more standard set of coercions. - */ -trait HiveTypeCoercion { - - import HiveTypeCoercion._ - - val typeCoercionRules = - PropagateTypes :: - ConvertNaNs :: - InConversion :: - WidenTypes :: - PromoteStrings :: - DecimalPrecision :: - BooleanEquality :: - StringToIntegralCasts :: - FunctionArgumentConversion :: - CaseWhenCoercion :: - IfCoercion :: - Division :: - PropagateTypes :: - AddCastForAutoCastInputTypes :: - Nil /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to @@ -202,8 +198,6 @@ trait HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -655,8 +649,6 @@ trait HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") @@ -714,7 +706,6 @@ trait HiveTypeCoercion { * [[AutoCastInputTypes]]. */ object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f7b8e21bed490..eae3666595a38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -113,8 +113,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("coalesce casts") { - val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -123,7 +122,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -135,7 +134,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for If") { - val rule = new HiveTypeCoercion { }.IfCoercion + val rule = HiveTypeCoercion.IfCoercion ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) @@ -148,19 +147,18 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for CaseKeyWhen") { - val cwc = new HiveTypeCoercion {}.CaseWhenCoercion - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) } test("type coercion simplification for equal to") { - val be = new HiveTypeCoercion {}.BooleanEquality + val be = HiveTypeCoercion.BooleanEquality ruleTest(be, EqualTo(Literal(true), Literal(1)), From 0eee0615894cda8ae1b2c8e61b8bda0ff648a219 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Jul 2015 01:02:33 -0700 Subject: [PATCH 072/274] [SQL] [MINOR] remove internalRowRDD in DataFrame Developers have already familiar with `queryExecution.toRDD` as internal row RDD, and we should not add new concept. Author: Wenchen Fan Closes #7116 from cloud-fan/internal-rdd and squashes the following commits: 24756ca [Wenchen Fan] remove internalRowRDD --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 +--- .../org/apache/spark/sql/execution/stat/FrequentItems.scala | 2 +- .../org/apache/spark/sql/execution/stat/StatFunctions.scala | 2 +- .../main/scala/org/apache/spark/sql/sources/commands.scala | 4 ++-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8fe1f7e34cb5e..caad2da80b1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1469,14 +1469,12 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - internalRowRdd.mapPartitions { rows => + queryExecution.toRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } - private[sql] def internalRowRdd = queryExecution.executedPlan.execute() - /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 3ebbf96090a55..4e2e2c210d5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index b624ef7e8fa1a..23ddfa9839e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -82,7 +82,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 42b51caab5ce9..7214eb0b4169a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => From 97652416e22ae7d4c471178377a7dda61afb1f7a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 01:08:20 -0700 Subject: [PATCH 073/274] [SPARK-8750][SQL] Remove the closure in functions.callUdf. Author: Reynold Xin Closes #7148 from rxin/calludf-closure and squashes the following commits: 00df372 [Reynold Xin] Fixed index out of bound exception. 4beba76 [Reynold Xin] [SPARK-8750][SQL] Remove the closure in functions.callUdf. --- .../main/scala/org/apache/spark/sql/functions.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5767668dd339b..4e8f3f96bf4db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1829,7 +1829,15 @@ object functions { */ @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + // Note: we avoid using closures here because on file systems that are case-insensitive, the + // compiled class file for the closure here will conflict with the one in callUDF (upper case). + val exprs = new Array[Expression](cols.size) + var i = 0 + while (i < cols.size) { + exprs(i) = cols(i).expr + i += 1 + } + UnresolvedFunction(udfName, exprs) } } From fdcad6ef48a9e790776c316124bd6478ab6bd5c8 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Wed, 1 Jul 2015 09:37:09 -0700 Subject: [PATCH 074/274] [SPARK-8763] [PYSPARK] executing run-tests.py with Python 2.6 fails with absence of subprocess.check_output function Running run-tests.py with Python 2.6 cause following error: ``` Running PySpark tests. Output is in python//Users/tomohiko/.jenkins/jobs/pyspark_test/workspace/python/unit-tests.log Will test against the following Python executables: ['python2.6', 'python3.4', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Traceback (most recent call last): File "./python/run-tests.py", line 196, in main() File "./python/run-tests.py", line 159, in main python_implementation = subprocess.check_output( AttributeError: 'module' object has no attribute 'check_output' ... ``` The cause of this error is using subprocess.check_output function, which exists since Python 2.7. (ref. https://docs.python.org/2.7/library/subprocess.html#subprocess.check_output) Author: cocoatomo Closes #7161 from cocoatomo/issues/8763-test-fails-py26 and squashes the following commits: cf4f901 [cocoatomo] [SPARK-8763] backport process.check_output function from Python 2.7 --- python/run-tests.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index b7737650daa54..7638854def2e8 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,6 +31,23 @@ import Queue else: import queue as Queue +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -156,11 +173,11 @@ def main(): task_queue = Queue.Queue() for python_exec in python_execs: - python_implementation = subprocess.check_output( + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) - LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: From 69c5dee2f01b1ae35bd813d31d46429a32cb475d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 1 Jul 2015 09:50:12 -0700 Subject: [PATCH 075/274] [SPARK-7714] [SPARKR] SparkR tests should use more specific expectations than expect_true 1. Update the pattern 'expect_true(a == b)' to 'expect_equal(a, b)'. 2. Update the pattern 'expect_true(inherits(a, b))' to 'expect_is(a, b)'. 3. Update the pattern 'expect_true(identical(a, b))' to 'expect_identical(a, b)'. Author: Sun Rui Closes #7152 from sun-rui/SPARK-7714 and squashes the following commits: 8ad2440 [Sun Rui] Fix test case errors. 8fe9f0c [Sun Rui] Update the pattern 'expect_true(identical(a, b))' to 'expect_identical(a, b)'. f1b8005 [Sun Rui] Update the pattern 'expect_true(inherits(a, b))' to 'expect_is(a, b)'. f631e94 [Sun Rui] Update the pattern 'expect_true(a == b)' to 'expect_equal(a, b)'. --- R/pkg/inst/tests/test_binaryFile.R | 2 +- R/pkg/inst/tests/test_binary_function.R | 4 +- R/pkg/inst/tests/test_includeJAR.R | 4 +- R/pkg/inst/tests/test_parallelize_collect.R | 2 +- R/pkg/inst/tests/test_rdd.R | 4 +- R/pkg/inst/tests/test_sparkSQL.R | 354 ++++++++++---------- R/pkg/inst/tests/test_take.R | 8 +- R/pkg/inst/tests/test_textFile.R | 6 +- R/pkg/inst/tests/test_utils.R | 4 +- 9 files changed, 194 insertions(+), 194 deletions(-) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index 4db7266abc8e2..ccaea18ecab2a 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -82,7 +82,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index a1e354e567be5..3be8c65a6c1a0 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") rdd<- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 8bc693be20c3c..844d86f3cc97f 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -31,7 +31,7 @@ runScript <- function() { test_that("sparkJars tag in SparkContext", { testOutput <- runScript() helloTest <- testOutput[1] - expect_true(helloTest == "Hello, Dave") + expect_equal(helloTest, "Hello, Dave") basicFunction <- testOutput[2] - expect_true(basicFunction == 4L) + expect_equal(basicFunction, "4") }) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db37..2552127cc547f 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 4fe653856756e..fc3c01d837de4 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 6a08f894313c4..0e4235ea8b4b3 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -61,7 +61,7 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) - expect_true(class(testStruct) == "structType") + expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() @@ -73,39 +73,39 @@ test_that("infer types", { test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlContext, rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -150,26 +150,26 @@ test_that("convert NAs to null type in DataFrames", { test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -219,21 +219,21 @@ test_that("create DataFrame with different data types", { test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) + expect_equal(count(rdd), 3) df <- jsonRDD(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) df <- jsonRDD(sqlContext, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { @@ -248,9 +248,9 @@ test_that("test cache, uncache and clearCache", { test_that("test tableNames and tables", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlContext)) == 1) + expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) - expect_true(count(df) == 1) + expect_equal(count(df), 1) dropTempTable(sqlContext, "table1") }) @@ -258,8 +258,8 @@ test_that("registerTempTable() results in a queryable table and sql() results in df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) dropTempTable(sqlContext, "table1") }) @@ -279,14 +279,14 @@ test_that("insertInto() on a registered table", { registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlContext, "select * from table1")) == 5) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlContext, "select * from table1")) == 2) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") }) @@ -294,16 +294,16 @@ test_that("table() returns a new DataFrame", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { @@ -311,9 +311,9 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -333,16 +333,16 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { @@ -352,7 +352,7 @@ test_that("objectFile() works with row serialization", { saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -363,32 +363,32 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { df <- jsonFile(sqlContext, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { @@ -401,9 +401,9 @@ test_that("multiple pipeline transformations starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - expect_true(inherits(second, "RDD")) - expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) @@ -430,36 +430,36 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) }) test_that("distinct() on DataFrames", { @@ -472,15 +472,15 @@ test_that("distinct() on DataFrames", { df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) }) test_that("sample on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) + expect_is(sampled, "DataFrame") sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) @@ -491,15 +491,15 @@ test_that("sample on a DataFrame", { test_that("select operators", { df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age @@ -518,50 +518,50 @@ test_that("select operators", { test_that("select with column", { df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) }) test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) }) test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) # Check if we can apply a user defined schema schema <- structType(structField("name", type = "string"), structField("age", type = "double")) df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df1, "DataFrame")) + expect_is(df1, "DataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Run the same with loadDF df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) @@ -569,8 +569,8 @@ test_that("write.df() as parquet file", { df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { @@ -580,17 +580,17 @@ test_that("test HiveContext", { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -643,65 +643,65 @@ test_that("string operators", { test_that("group by", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, "name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) # test suites for %in% filtered3 <- filter(df, "age in (19)") @@ -727,29 +727,29 @@ test_that("join() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) @@ -775,50 +775,50 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("mutate() and rename()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("write.df() on DataFrame and works with parquetFile", { df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath) - expect_true(inherits(parquetDF, "DataFrame")) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) @@ -828,8 +828,8 @@ test_that("parquetFile works with multiple input paths", { parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df)*2) }) test_that("describe() on a DataFrame", { @@ -851,58 +851,58 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) row.names(expected) <- row.names(actual) # identical on two dataframes does not work here. Don't know why. # use identical on all columns as a workaround. - expect_true(identical(expected$age, actual$age)) - expect_true(identical(expected$height, actual$height)) - expect_true(identical(expected$name, actual$name)) + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with threshold expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { @@ -915,22 +915,22 @@ test_that("fillna() on a DataFrame", { expected$age[is.na(expected$age)] <- 50 expected$height[is.na(expected$height)] <- 50.6 actual <- collect(fillna(df, 50.6)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$age[is.na(expected$age)] <- 50 actual <- collect(fillna(df, 50.6, "age")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown", c("age", "name"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # fill with named list @@ -939,7 +939,7 @@ test_that("fillna() on a DataFrame", { expected$height[is.na(expected$height)] <- 50.6 expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index c5eb417b40159..c2c724cdc762f 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,8 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 092ad9dc10c2e..58318dfef71ab 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 15030e6f1d77e..aa0d2a66b9082 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -45,10 +45,10 @@ test_that("serializeToBytes on RDD", { writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") + expect_equal(getSerializedMode(ser.rdd), "byte") unlink(fileName) }) From 4137f769b84300648ad933b0b3054d69a7316745 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 10:30:54 -0700 Subject: [PATCH 076/274] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types. This patch doesn't actually introduce any code that uses the new ExpectsInputTypes. It just adds the trait so others can use it. Also renamed the old expectsInputTypes function to just inputTypes. We should add implicit type casting also in the future. Author: Reynold Xin Closes #7151 from rxin/expects-input-types and squashes the following commits: 16cf07b [Reynold Xin] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 - .../catalyst/analysis/HiveTypeCoercion.scala | 8 ++--- .../sql/catalyst/expressions/Expression.scala | 29 ++++++++++++++++--- .../spark/sql/catalyst/expressions/math.scala | 6 ++-- .../spark/sql/catalyst/expressions/misc.scala | 8 ++--- .../sql/catalyst/expressions/predicates.scala | 6 ++-- .../expressions/stringOperations.scala | 10 +++---- 7 files changed, 44 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a069b4710f38c..583338da57117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.types._ * Throws user facing errors when passed invalid queries that fail to analyze. */ trait CheckAnalysis { - self: Analyzer => /** * Override to provide additional checks for correct analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a9d396d1faeeb..2ab5cb666fbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -45,7 +45,7 @@ object HiveTypeCoercion { IfCoercion :: Division :: PropagateTypes :: - AddCastForAutoCastInputTypes :: + ImplicitTypeCasts :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -705,13 +705,13 @@ object HiveTypeCoercion { * Casts types according to the expected input types for Expressions that have the trait * [[AutoCastInputTypes]]. */ - object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { + object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => - val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes => + val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b5063f32fa529..e18a3118945e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -265,17 +265,38 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + /** * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. */ -trait AutoCastInputTypes { - self: Expression => +trait AutoCastInputTypes { self: Expression => - def expectedChildTypes: Seq[DataType] + def inputTypes: Seq[DataType] override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, // so type mismatch error won't be reported here, but for underling `Cast`s. TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index da63f2fa970cf..b51318dd5044c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) override def toString: String = s"$name($left, $right)" @@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia case class Bin(child: Expression) extends UnaryExpression with Serializable with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(LongType) + override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a7bcbe46c339a..407023e472081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -36,7 +36,7 @@ case class Md5(child: Expression) override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression) override def toString: String = s"SHA2($left, $right)" - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) @@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -179,7 +179,7 @@ case class Crc32(child: Expression) override def dataType: DataType = LongType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 98cd5aa8148c4..a777f77add2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -72,7 +72,7 @@ trait PredicateHelper { case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def toString: String = s"NOT $child" - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType) override def eval(input: InternalRow): Any = { child.eval(input) match { @@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any]) case class And(left: Expression, right: Expression) extends BinaryExpression with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "&&" @@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression) case class Or(left: Expression, right: Expression) extends BinaryExpression with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "||" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index ce184e4f32f18..4cbfc4e084948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) @@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def eval(input: InternalRow): Any = { val leftEval = left.eval(input) @@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) if (str.dataType == BinaryType) str.dataType else StringType } - override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil @@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) */ case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val string = child.eval(input) From 31b4a3d7f2be9053a041e5ae67418562a93d80d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Jul 2015 10:31:35 -0700 Subject: [PATCH 077/274] [SPARK-8621] [SQL] support empty string as column name improve the empty check in `parseAttributeName` so that we can allow empty string as column name. Close https://github.com/apache/spark/pull/7117 Author: Wenchen Fan Closes #7149 from cloud-fan/8621 and squashes the following commits: efa9e3e [Wenchen Fan] support empty string --- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 4 ++-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b009a200b920f..e911b907e8536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -161,7 +161,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { if (tmp.nonEmpty) throw e inBacktick = true } else if (char == '.') { - if (tmp.isEmpty) throw e + if (name(i - 1) == '.' || i == name.length - 1) throw e nameParts += tmp.mkString tmp.clear() } else { @@ -170,7 +170,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } i += 1 } - if (tmp.isEmpty || inBacktick) throw e + if (inBacktick) throw e nameParts += tmp.mkString nameParts.toSeq } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50d324c0686fa..afb1cf5f8d1cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -730,4 +730,11 @@ class DataFrameSuite extends QueryTest { val res11 = ctx.range(-1).select("id") assert(res11.count == 0) } + + test("SPARK-8621: support empty string column name") { + val df = Seq(Tuple1(1)).toDF("").as("t") + // We should allow empty string as column name + df.col("") + df.col("t.``") + } } From 184de91d15a4bfc5c014e8cf86211874bba4593f Mon Sep 17 00:00:00 2001 From: lewuathe Date: Wed, 1 Jul 2015 11:14:07 -0700 Subject: [PATCH 078/274] [SPARK-6263] [MLLIB] Python MLlib API missing items: Utils Implement missing API in pyspark. MLUtils * appendBias * loadVectors `kFold` is also missing however I am not sure `ClassTag` can be passed or restored through python. Author: lewuathe Closes #5707 from Lewuathe/SPARK-6263 and squashes the following commits: 16863ea [lewuathe] Merge master 3fc27e7 [lewuathe] Merge branch 'master' into SPARK-6263 6084e9c [lewuathe] Resolv conflict d2aa2a0 [lewuathe] Resolv conflict 9c329d8 [lewuathe] Fix efficiency 3a12a2d [lewuathe] Merge branch 'master' into SPARK-6263 1d4714b [lewuathe] Fix style b29e2bc [lewuathe] Remove scipy dependencies e32eb40 [lewuathe] Merge branch 'master' into SPARK-6263 25d3c9d [lewuathe] Remove unnecessary imports 7ec04db [lewuathe] Resolv conflict 1502d13 [lewuathe] Resolv conflict d6bd416 [lewuathe] Check existence of scipy.sparse 5d555b1 [lewuathe] Construct scipy.sparse matrix c345a44 [lewuathe] Merge branch 'master' into SPARK-6263 b8b5ef7 [lewuathe] Fix unnecessary sort method d254be7 [lewuathe] Merge branch 'master' into SPARK-6263 62a9c7e [lewuathe] Fix appendBias return type 454c73d [lewuathe] Merge branch 'master' into SPARK-6263 a353354 [lewuathe] Remove unnecessary appendBias implementation 44295c2 [lewuathe] Merge branch 'master' into SPARK-6263 64f72ad [lewuathe] Merge branch 'master' into SPARK-6263 c728046 [lewuathe] Fix style 2980569 [lewuathe] [SPARK-6263] Python MLlib API missing items: Utils --- .../mllib/api/python/PythonMLLibAPI.scala | 9 ++++ python/pyspark/mllib/tests.py | 43 +++++++++++++++++++ python/pyspark/mllib/util.py | 22 ++++++++++ 3 files changed, 74 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a66a404d5c846..458fab48fef5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -75,6 +75,15 @@ private[python] class PythonMLLibAPI extends Serializable { minPartitions: Int): JavaRDD[LabeledPoint] = MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) + /** + * Loads and serializes vectors saved with `RDD#saveAsTextFile`. + * @param jsc Java SparkContext + * @param path file or directory path in any Hadoop-supported file system URI + * @return serialized vectors in a RDD + */ + def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] = + MLUtils.loadVectors(jsc.sc, path) + private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f0091d6faccce..49ce125de7e78 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,6 +54,7 @@ from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext @@ -1290,6 +1291,48 @@ def func(rdd): self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 348238319e407..875d3b2d642c6 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None): minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) + @staticmethod + def appendBias(data): + """ + Returns a new vector with `1.0` (bias) appended to + the end of the input vector. + """ + vec = _convert_to_vector(data) + if isinstance(vec, SparseVector): + newIndices = np.append(vec.indices, len(vec)) + newValues = np.append(vec.values, 1.0) + return SparseVector(len(vec) + 1, newIndices, newValues) + else: + return _convert_to_vector(np.append(vec.toArray(), 1.0)) + + @staticmethod + def loadVectors(sc, path): + """ + Loads vectors saved using `RDD[Vector].saveAsTextFile` + with the default number of partitions. + """ + return callMLlibFunc("loadVectors", sc, path) + class Saveable(object): """ From 2012913355993e6516e4c81dbc92e579977131da Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 1 Jul 2015 11:17:56 -0700 Subject: [PATCH 079/274] [SPARK-8308] [MLLIB] add missing save load for python example jira: https://issues.apache.org/jira/browse/SPARK-8308 1. add some missing save/load in python examples. , LogisticRegression, LinearRegression and NaiveBayes 2. tune down iterations for MatrixFactorization, since current number will trigger StackOverflow for default java configuration (>1M) Author: Yuhao Yang Closes #6760 from hhbyyh/docUpdate and squashes the following commits: 9bd3383 [Yuhao Yang] update scala example 8a44692 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into docUpdate 077cbb8 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into docUpdate 3e948dc [Yuhao Yang] add missing save load for python example --- docs/mllib-collaborative-filtering.md | 6 +++--- docs/mllib-linear-methods.md | 12 ++++++++++-- docs/mllib-naive-bayes.md | 6 +++++- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index dfdf6216b270c..eedc23424ad54 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -77,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -149,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -210,7 +210,7 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 2a2a7c13186d8..3927d65fbf8fb 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -499,7 +499,7 @@ Note that the Python API does not yet support multiclass classification and mode will in the future. {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint from numpy import array @@ -518,6 +518,10 @@ model = LogisticRegressionWithLBFGS.train(parsedData) labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %}
@@ -668,7 +672,7 @@ values. We compute the mean squared error at the end to evaluate Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel from numpy import array # Load and parse the data @@ -686,6 +690,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index bf6d124fd5d8d..e73bd30f3a90a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -119,7 +119,7 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint @@ -140,6 +140,10 @@ model = NaiveBayes.train(training, 1.0) # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} From b8faa32875aa560cdce340266d898902a920418d Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 1 Jul 2015 11:57:52 -0700 Subject: [PATCH 080/274] [SPARK-8765] [MLLIB] [PYTHON] removed flaky python PIC test See failure: [https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/36133/console] CC yanboliang mengxr Author: Joseph K. Bradley Closes #7164 from jkbradley/pic-python-test and squashes the following commits: 156d55b [Joseph K. Bradley] removed flaky python PIC test --- python/pyspark/mllib/clustering.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e3c8a24c4a751..a3eab635282f6 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -288,16 +288,12 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = PowerIterationClustering.train(rdd, 2, 100) >>> model.k 2 - >>> sorted(model.assignments().collect()) - [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = PowerIterationClusteringModel.load(sc, path) >>> sameModel.k 2 - >>> sorted(sameModel.assignments().collect()) - [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... >>> from shutil import rmtree >>> try: ... rmtree(path) From 75b9fe4c5ff6f206c6fc9100563d625b39f142ba Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 1 Jul 2015 11:59:24 -0700 Subject: [PATCH 081/274] [SPARK-8378] [STREAMING] Add the Python API for Flume Author: zsxwing Closes #6830 from zsxwing/flume-python and squashes the following commits: 78dfdac [zsxwing] Fix the compile error in the test code f1bf3c0 [zsxwing] Address TD's comments 0449723 [zsxwing] Add sbt goal streaming-flume-assembly/assembly e93736b [zsxwing] Fix the test case for determine_modules_to_test 9d5821e [zsxwing] Fix pyspark_core dependencies f9ee681 [zsxwing] Merge branch 'master' into flume-python 7a55837 [zsxwing] Add streaming_flume_assembly to run-tests.py b96b0de [zsxwing] Merge branch 'master' into flume-python ce85e83 [zsxwing] Fix incompatible issues for Python 3 01cbb3d [zsxwing] Add import sys 152364c [zsxwing] Fix the issue that StringIO doesn't work in Python 3 14ba0ff [zsxwing] Add flume-assembly for sbt building b8d5551 [zsxwing] Merge branch 'master' into flume-python 4762c34 [zsxwing] Fix the doc 0336579 [zsxwing] Refactor Flume unit tests and also add tests for Python API 9f33873 [zsxwing] Add the Python API for Flume --- dev/run-tests.py | 7 +- dev/sparktestsupport/modules.py | 15 +- docs/streaming-flume-integration.md | 18 ++ docs/streaming-programming-guide.md | 2 +- .../main/python/streaming/flume_wordcount.py | 55 +++++ external/flume-assembly/pom.xml | 135 +++++++++++ .../streaming/flume/FlumeTestUtils.scala | 116 ++++++++++ .../spark/streaming/flume/FlumeUtils.scala | 76 ++++++- .../flume/PollingFlumeTestUtils.scala | 209 ++++++++++++++++++ .../flume/FlumePollingStreamSuite.scala | 173 +++------------ .../streaming/flume/FlumeStreamSuite.scala | 106 ++------- pom.xml | 1 + project/SparkBuild.scala | 6 +- python/pyspark/streaming/flume.py | 147 ++++++++++++ python/pyspark/streaming/tests.py | 179 ++++++++++++++- 15 files changed, 1009 insertions(+), 236 deletions(-) create mode 100644 examples/src/main/python/streaming/flume_wordcount.py create mode 100644 external/flume-assembly/pom.xml create mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala create mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala create mode 100644 python/pyspark/streaming/flume.py diff --git a/dev/run-tests.py b/dev/run-tests.py index 4596e07014733..1f0d218514f92 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -96,8 +96,8 @@ def determine_modules_to_test(changed_modules): ['examples', 'graphx'] >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) >>> x # doctest: +NORMALIZE_WHITESPACE - ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \ - 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql'] + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. No module depends on root, so if it appears then it will be @@ -293,7 +293,8 @@ def build_spark_sbt(hadoop_version): build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "assembly/assembly", - "streaming-kafka-assembly/assembly"] + "streaming-kafka-assembly/assembly", + "streaming-flume-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index efe3a897e9c10..993583e2f4119 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -203,7 +203,7 @@ def contains_file(self, filename): streaming_flume = Module( - name="streaming_flume", + name="streaming-flume", dependencies=[streaming], source_file_regexes=[ "external/flume", @@ -214,6 +214,15 @@ def contains_file(self, filename): ) +streaming_flume_assembly = Module( + name="streaming-flume-assembly", + dependencies=[streaming_flume, streaming_flume_sink], + source_file_regexes=[ + "external/flume-assembly", + ] +) + + mllib = Module( name="mllib", dependencies=[streaming, sql], @@ -241,7 +250,7 @@ def contains_file(self, filename): pyspark_core = Module( name="pyspark-core", - dependencies=[mllib, streaming, streaming_kafka], + dependencies=[], source_file_regexes=[ "python/(?!pyspark/(ml|mllib|sql|streaming))" ], @@ -281,7 +290,7 @@ def contains_file(self, filename): pyspark_streaming = Module( name="pyspark-streaming", - dependencies=[pyspark_core, streaming, streaming_kafka], + dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly], source_file_regexes=[ "python/pyspark/streaming" ], diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 8d6e74370918f..de0461010daec 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -58,6 +58,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). +
+ from pyspark.streaming.flume import FlumeUtils + + flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). +
Note that the hostname should be the same as the one used by the resource manager in the @@ -135,6 +144,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
+ from pyspark.streaming.flume import FlumeUtils + + addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] + flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). +
See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index b784d59666fec..e72d5580dae55 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 0000000000000..091b64d8c4af4 --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ + spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 0000000000000..8565cd83edfa2 --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,135 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + org.apache.avro + avro + ${avro.version} + + + org.apache.avro + avro-ipc + ${avro.version} + + + io.netty + netty + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + org.mortbay.jetty + servlet-api + + + org.apache.velocity + velocity + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 0000000000000..9d9c3b189415f --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.toList) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 44dec45c227ca..095bfb0c73a9a 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,10 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} +import scala.collection.JavaConversions._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -236,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.length == ports.length) + val addresses = hosts.zip(ports).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 0000000000000..91d63d49dbec3 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): JList[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + Map[String, String](s"test-$t" -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d772b9ca9b570..d5f9a0aa38f9f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,47 +18,33 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder -import org.scalatest.concurrent.Eventually._ - +import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -89,146 +75,55 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log } private def testFlumePolling(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - writeAndVerify(Seq(sink), Seq(channel)) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + try { + val port = utils.startSingleSink() + + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() try { - writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() + utils.close() } } - def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 - } - } - assert(counter === totalEventsPerChannel * channels.size) - } - ssc.stop() - } - - def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach + try { + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.map { + case kv => (kv._1.toString, kv._2.toString) + }).map(mapAsJavaMap) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index c926359987d89..5bc4cdf65306c 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,20 +17,12 @@ package org.apache.spark.streaming.flume -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -41,22 +33,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.util.Utils class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,19 +49,29 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + output should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -98,58 +88,6 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) - output should be (input) - } - } - /** Class to create socket channel with compression */ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { diff --git a/pom.xml b/pom.xml index 94dd512cfb618..211da9ee74a3f 100644 --- a/pom.xml +++ b/pom.xml @@ -102,6 +102,7 @@ external/twitter external/flume external/flume-sink + external/flume-assembly external/mqtt external/zeromq examples diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f5f1c9a1a247a..4ef4dc8bdc039 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -347,7 +347,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py new file mode 100644 index 0000000000000..cbb573f226bbe --- /dev/null +++ b/python/pyspark/streaming/flume.py @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version >= "3": + from io import BytesIO +else: + from StringIO import StringIO +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int +from pyspark.streaming import DStream + +__all__ = ['FlumeUtils', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class FlumeUtils(object): + + @staticmethod + def createStream(ssc, hostname, port, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + enableDecompression=False, + bodyDecoder=utf8_decoder): + """ + Create an input stream that pulls events from Flume. + + :param ssc: StreamingContext object + :param hostname: Hostname of the slave machine to which the flume data will be sent + :param port: Port of the slave machine to which the flume data will be sent + :param storageLevel: Storage level to use for storing the received objects + :param enableDecompression: Should netty server decompress input stream + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def createPollingStream(ssc, addresses, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + maxBatchSize=1000, + parallelism=5, + bodyDecoder=utf8_decoder): + """ + Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + This stream will poll the sink for data and will pull events as they are available. + + :param ssc: StreamingContext object + :param addresses: List of (host, port)s on which the Spark Sink is running. + :param storageLevel: Storage level to use for storing the received objects + :param maxBatchSize: The maximum number of events to be pulled from the Spark sink + in a single RPC call + :param parallelism: Number of concurrent requests this stream should send to the sink. + Note that having a higher number of requests concurrently being pulled + will result in this stream using more threads + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + hosts = [] + ports = [] + for (host, port) in addresses: + hosts.append(host) + ports.append(port) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createPollingStream( + ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def _toPythonDStream(ssc, jstream, bodyDecoder): + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + + def func(event): + headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0]) + headers = {} + strSer = UTF8Deserializer() + for i in range(0, read_int(headersBytes)): + key = strSer.loads(headersBytes) + value = strSer.loads(headersBytes) + headers[key] = value + body = bodyDecoder(event[1]) + return (headers, body) + return stream.map(func) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Flume libraries not found in class path. Try one of the following. + + 1. Include the Flume library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 91ce681fbe169..188c8ff12067e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -38,6 +38,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition +from pyspark.streaming.flume import FlumeUtils class PySparkStreamingTestCase(unittest.TestCase): @@ -677,7 +678,156 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) -if __name__ == "__main__": + +class FlumeStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(FlumeStreamTests, self).setUp() + + utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + super(FlumeStreamTests, self).tearDown() + + def _startContext(self, n, compressed): + # Start the StreamingContext and also collect the result + dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), + enableDecompression=compressed) + result = [] + + def get_output(_, rdd): + for event in rdd.collect(): + if len(result) < n: + result.append(event) + dstream.foreachRDD(get_output) + self.ssc.start() + return result + + def _validateResult(self, input, result): + # Validate both the header and the body + header = {"test": "header"} + self.assertEqual(len(input), len(result)) + for i in range(0, len(input)): + self.assertEqual(header, result[i][0]) + self.assertEqual(input[i], result[i][1]) + + def _writeInput(self, input, compressed): + # Try to write input to the receiver until success or timeout + start_time = time.time() + while True: + try: + self._utils.writeInput(input, compressed) + break + except: + if time.time() - start_time < self.timeout: + time.sleep(0.01) + else: + raise + + def test_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), False) + self._writeInput(input, False) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + def test_compressed_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), True) + self._writeInput(input, True) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + +class FlumePollingStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + maxAttempts = 5 + + def setUp(self): + utilsClz = \ + self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + def _writeAndVerify(self, ports): + # Set up the streaming context and input streams + ssc = StreamingContext(self.sc, self.duration) + try: + addresses = [("localhost", port) for port in ports] + dstream = FlumeUtils.createPollingStream( + ssc, + addresses, + maxBatchSize=self._utils.eventsPerBatch(), + parallelism=5) + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + dstream.foreachRDD(get_output) + ssc.start() + self._utils.sendDatAndEnsureAllDataHasBeenReceived() + + self.wait_for(outputBuffer, self._utils.getTotalEvents()) + outputHeaders = [event[0] for event in outputBuffer] + outputBodies = [event[1] for event in outputBuffer] + self._utils.assertOutput(outputHeaders, outputBodies) + finally: + ssc.stop(False) + + def _testMultipleTimes(self, f): + attempt = 0 + while True: + try: + f() + break + except: + attempt += 1 + if attempt >= self.maxAttempts: + raise + else: + import traceback + traceback.print_exc() + + def _testFlumePolling(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def _testFlumePollingMultipleHosts(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def test_flume_polling(self): + self._testMultipleTimes(self._testFlumePolling) + + def test_flume_polling_multiple_hosts(self): + self._testMultipleTimes(self._testFlumePollingMultipleHosts) + + +def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") jars = glob.glob( @@ -692,5 +842,30 @@ def test_kafka_rdd_with_leaders(self): raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " "remove all but one") % kafka_assembly_dir) else: - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0] + return jars[0] + + +def search_flume_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") + jars = glob.glob( + os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + +if __name__ == "__main__": + kafka_assembly_jar = search_kafka_assembly_jar() + flume_assembly_jar = search_flume_assembly_jar() + jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() From 9f7db3486fcb403cae8da9dfce8978373c3f47b7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 1 Jul 2015 12:33:24 -0700 Subject: [PATCH 082/274] [SPARK-7820] [BUILD] Fix Java8-tests suite compile and test error under sbt Author: jerryshao Closes #7120 from jerryshao/SPARK-7820 and squashes the following commits: 6902439 [jerryshao] fix Java8-tests suite compile error under sbt --- extras/java8-tests/pom.xml | 8 ++++++++ project/SparkBuild.scala | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index f138251748c9e..3636a9037d43f 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4ef4dc8bdc039..5f389bcc9ceeb 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -161,7 +161,7 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExludedDependencies.settings ++ Revolver.settings)) + .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -246,7 +246,7 @@ object Flume { This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. */ -object ExludedDependencies { +object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } ) From 3083e17645e4b707646fe48e406e02c156a0f37b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Jul 2015 12:39:57 -0700 Subject: [PATCH 083/274] [QUICKFIX] [SQL] fix copy of generated row copy() of generated Row doesn't check nullability of columns Author: Davies Liu Closes #7163 from davies/fix_copy and squashes the following commits: 661a206 [Davies Liu] fix copy of generated row --- .../sql/catalyst/expressions/codegen/GenerateProjection.scala | 2 +- .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 5be47175fa7f1..3c7ee9cc16599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -148,7 +148,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n") val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""arr[$i] = c$i;""" + s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 7d95ef7f710af..3171caf6ad77f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -136,6 +136,9 @@ trait ExpressionEvalHelper { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } + if (actual.copy() != expectedRow) { + fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") + } } protected def checkEvaluationWithOptimization( From 1ce6428907b4ddcf52dbf0c86196d82ab7392442 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 1 Jul 2015 20:40:47 +0100 Subject: [PATCH 084/274] [SPARK-3444] [CORE] Restore INFO level after log4j test. Otherwise other tests don't log anything useful... Author: Marcelo Vanzin Closes #7140 from vanzin/SPARK-3444 and squashes the following commits: de14836 [Marcelo Vanzin] Better fix. 6cff13a [Marcelo Vanzin] [SPARK-3444] [core] Restore INFO level after log4j test. --- .../scala/org/apache/spark/util/UtilsSuite.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index baa4c661cc21e..251a797dc28a2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -486,11 +486,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // Test for using the util function to change our log levels. test("log4j log level change") { - Utils.setLogLevel(org.apache.log4j.Level.ALL) - assert(log.isInfoEnabled()) - Utils.setLogLevel(org.apache.log4j.Level.ERROR) - assert(!log.isInfoEnabled()) - assert(log.isErrorEnabled()) + val current = org.apache.log4j.Logger.getRootLogger().getLevel() + try { + Utils.setLogLevel(org.apache.log4j.Level.ALL) + assert(log.isInfoEnabled()) + Utils.setLogLevel(org.apache.log4j.Level.ERROR) + assert(!log.isInfoEnabled()) + assert(log.isErrorEnabled()) + } finally { + // Best effort at undoing changes this test made. + Utils.setLogLevel(current) + } } test("deleteRecursively") { From f958f27e2056f9e380373c2807d8bb5977ecf269 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Jul 2015 16:43:18 -0700 Subject: [PATCH 085/274] [SPARK-8766] support non-ascii character in column names Use UTF-8 to encode the name of column in Python 2, or it may failed to encode with default encoding ('ascii'). This PR also fix a bug when there is Java exception without error message. Author: Davies Liu Closes #7165 from davies/non_ascii and squashes the following commits: 02cb61a [Davies Liu] fix tests 3b09d31 [Davies Liu] add encoding in header 867754a [Davies Liu] support non-ascii character in column names --- python/pyspark/sql/dataframe.py | 3 +-- python/pyspark/sql/tests.py | 9 +++++++++ python/pyspark/sql/types.py | 2 ++ python/pyspark/sql/utils.py | 6 +++--- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4b9efa0a210fb..273a40dd526cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -484,13 +484,12 @@ def dtypes(self): return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property - @ignore_unicode_prefix @since(1.3) def columns(self): """Returns all column names as a list. >>> df.columns - [u'age', u'name'] + ['age', 'name'] """ return [f.name for f in self.schema.fields] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5af2ce09bc122..333378c7f1854 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1,3 +1,4 @@ +# -*- encoding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -628,6 +629,14 @@ def test_access_column(self): self.assertRaises(IndexError, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) + def test_column_name_with_non_ascii(self): + df = self.sqlCtx.createDataFrame([(1,)], ["数量"]) + self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema) + self.assertEqual("DataFrame[数量: bigint]", str(df)) + self.assertEqual([("数量", 'bigint')], df.dtypes) + self.assertEqual(1, df.select("数量").first()[0]) + self.assertEqual(1, df.select(df["数量"]).first()[0]) + def test_access_nested_types(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ae9344e6106a4..160df40d65cc1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -324,6 +324,8 @@ def __init__(self, name, dataType, nullable=True, metadata=None): False """ assert isinstance(dataType, DataType), "dataType should be DataType" + if not isinstance(name, str): + name = name.encode('utf-8') self.name = name self.dataType = dataType self.nullable = nullable diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8096802e7302f..cc5b2c088b7cc 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -29,9 +29,9 @@ def deco(*a, **kw): try: return f(*a, **kw) except py4j.protocol.Py4JJavaError as e: - cls, msg = e.java_exception.toString().split(': ', 1) - if cls == 'org.apache.spark.sql.AnalysisException': - raise AnalysisException(msg) + s = e.java_exception.toString() + if s.startswith('org.apache.spark.sql.AnalysisException: '): + raise AnalysisException(s.split(': ', 1)[1]) raise return deco From 272778999823ed79af92280350c5869a87a21f29 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 16:56:48 -0700 Subject: [PATCH 086/274] [SPARK-8770][SQL] Create BinaryOperator abstract class. Our current BinaryExpression abstract class is not for generic binary expressions, i.e. it requires left/right children to have the same type. However, due to its name, contributors build new binary expressions that don't have that assumption (e.g. Sha) and still extend BinaryExpression. This patch creates a new BinaryOperator abstract class, and update the analyzer o only apply type casting rule there. This patch also adds the notion of "prettyName" to expressions, which defines the user-facing name for the expression. Author: Reynold Xin Closes #7170 from rxin/binaryoperator and squashes the following commits: 51264a5 [Reynold Xin] [SPARK-8770][SQL] Create BinaryOperator abstract class. --- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../expressions/ExpectsInputTypes.scala | 59 +++++++ .../sql/catalyst/expressions/Expression.scala | 161 +++++++++--------- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 6 - .../sql/catalyst/expressions/arithmetic.scala | 14 +- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 2 - .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 2 - .../expressions/stringOperations.scala | 26 +-- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- 12 files changed, 170 insertions(+), 135 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2ab5cb666fbcd..8420c54f7c335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,6 +150,7 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ + // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -159,19 +160,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + case b @ BinaryOperator(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + case b @ BinaryOperator(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryExpression(StringNaN, right @ FloatType()) => + case b @ BinaryOperator(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryExpression(left @ FloatType(), StringNaN) => + case b @ BinaryOperator(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryExpression(left @ StringNaN, StringNaN) => + case b @ BinaryOperator(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -245,12 +246,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryExpressions. + // Also widen types for BinaryOperator. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -478,7 +479,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala new file mode 100644 index 0000000000000..450fc4165f93b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.DataType + + +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait AutoCastInputTypes { self: Expression => + + def inputTypes: Seq[DataType] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e18a3118945e8..cafbbafdca207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,17 +119,6 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns a string representation of this expression that does not have developer centric - * debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -154,71 +143,40 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess -} - -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable - - override def toString: String = s"($left $symbol $right)" /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f accepts two variable names and returns Java code to compute the output. + * Returns a user-facing string representation of this expression's name. + * This should usually match the name of the function in SQL. */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def prettyName: String = getClass.getSimpleName.toLowerCase /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Returns a user-facing string representation of this expression, i.e. does not have developer + * centric debugging information like the expression id. */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } -} -private[sql] object BinaryExpression { - def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def toString: String = prettyName + children.mkString("(", ",", ")") } + +/** + * A leaf expression, i.e. one without any child expressions. + */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } + +/** + * An expression with one input and one output. The output is by default evaluated to null + * if the input is evaluated to null. + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -265,39 +223,76 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } + /** - * An trait that gets mixin to define the expected input types of an expression. + * An expression with two inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. */ -trait ExpectsInputTypes { self: Expression => +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + * @param f accepts two variable names and returns Java code to compute the output. */ - def inputTypes: Seq[Any] + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } } + /** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. + * An expression that has two inputs that are expected to the be same type. If the two inputs have + * different types, the analyzer will find the tightest common type and do the proper type casting. */ -trait AutoCastInputTypes { self: Expression => +abstract class BinaryOperator extends BinaryExpression { + self: Product => - def inputTypes: Seq[DataType] + def symbol: String - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } + override def toString: String = s"($left $symbol $right)" +} + + +private[sql] object BinaryOperator { + def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ebabb6f117851..caf021b016a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index a9fc54c548f49..da520f56b430e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -128,7 +128,6 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -162,7 +161,6 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -401,8 +399,6 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -494,8 +490,6 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString: String = s"SUM($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5363b3556886a..4fbf4c87009c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def toString: String = s"positive($child)" + override def prettyName: String = "positive" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,8 +69,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { - override def toString: String = s"Abs($child)" - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -79,10 +77,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryExpression { +abstract class BinaryArithmetic extends BinaryOperator { self: Product => - override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -360,7 +357,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - override def toString: String = s"MaxOf($left, $right)" + + override def symbol: String = "max" + override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -413,5 +412,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def toString: String = s"MinOf($left, $right)" + override def symbol: String = "min" + override def prettyName: String = symbol } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5def57b067424..67e7dc4ec8b14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString: String = s"Array(${children.mkString(",")})" + override def prettyName: String = "array" } /** @@ -71,4 +71,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } + + override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 78be2824347d7..145d323a9f0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,8 +38,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a777f77add2db..34df89a163895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryExpression with Predicate { +abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index daa9f4403ffab..5d51a4ca65332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,8 +137,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType - override def symbol: String = "++=" - override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 4cbfc4e084948..b020f2bbc5818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,8 +75,6 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "LIKE" - // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -101,14 +99,16 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,9 +134,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase() - - override def toString: String = s"Upper($child)" + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -148,9 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase() - - override def toString: String = s"Lower($child)" + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -178,8 +174,6 @@ trait StringComparison extends AutoCastInputTypes { } } - override def symbol: String = nodeName - override def toString: String = s"$nodeName($left, $right)" } @@ -284,12 +278,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } - - override def toString: String = len match { - // TODO: This is broken because max is not an integer value. - case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" - case _ => s"SUBSTR($str, $pos, $len)" - } } /** @@ -304,9 +292,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } - override def toString: String = s"length($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } + + override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index bda217935cb05..86792f0217572 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryExpression => actual.append(b.symbol); + case b: BinaryOperator => actual.append(b.symbol); case l: Literal => actual.append(l.toString); } From 3a342dedc04799948bf6da69843bd1a91202ffe5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 16:59:39 -0700 Subject: [PATCH 087/274] Revert "[SPARK-8770][SQL] Create BinaryOperator abstract class." This reverts commit 272778999823ed79af92280350c5869a87a21f29. --- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../expressions/ExpectsInputTypes.scala | 59 ------- .../sql/catalyst/expressions/Expression.scala | 161 +++++++++--------- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 6 + .../sql/catalyst/expressions/arithmetic.scala | 14 +- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 2 + .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 2 + .../expressions/stringOperations.scala | 26 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- 12 files changed, 135 insertions(+), 170 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8420c54f7c335..2ab5cb666fbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,7 +150,6 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ - // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -160,19 +159,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryOperator(StringNaN, right @ DoubleType()) => + case b @ BinaryExpression(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryOperator(left @ DoubleType(), StringNaN) => + case b @ BinaryExpression(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryOperator(StringNaN, right @ FloatType()) => + case b @ BinaryExpression(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryOperator(left @ FloatType(), StringNaN) => + case b @ BinaryExpression(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryOperator(left @ StringNaN, StringNaN) => + case b @ BinaryExpression(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -246,12 +245,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryOperator. + // Also widen types for BinaryExpressions. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -479,7 +478,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala deleted file mode 100644 index 450fc4165f93b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.types.DataType - - -/** - * An trait that gets mixin to define the expected input types of an expression. - */ -trait ExpectsInputTypes { self: Expression => - - /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. - * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). - */ - def inputTypes: Seq[Any] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess - } -} - -/** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. - */ -trait AutoCastInputTypes { self: Expression => - - def inputTypes: Seq[DataType] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cafbbafdca207..e18a3118945e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,6 +119,17 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) + /** + * Returns a string representation of this expression that does not have developer centric + * debugging information like the expression id. + */ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString + } + /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -143,40 +154,71 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess +} + +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable + + override def toString: String = s"($left $symbol $right)" /** - * Returns a user-facing string representation of this expression's name. - * This should usually match the name of the function in SQL. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. */ - def prettyName: String = getClass.getSimpleName.toLowerCase + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } /** - * Returns a user-facing string representation of this expression, i.e. does not have developer - * centric debugging information like the expression id. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } - - override def toString: String = prettyName + children.mkString("(", ",", ")") } +private[sql] object BinaryExpression { + def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) +} -/** - * A leaf expression, i.e. one without any child expressions. - */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } - -/** - * An expression with one input and one output. The output is by default evaluated to null - * if the input is evaluated to null. - */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -223,76 +265,39 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } - /** - * An expression with two inputs and one output. The output is by default evaluated to null - * if any input is evaluated to null. + * An trait that gets mixin to define the expected input types of an expression. */ -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable +trait ExpectsInputTypes { self: Expression => /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. * - * @param f accepts two variable names and returns Java code to compute the output. + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def inputTypes: Seq[Any] - /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess } } - /** - * An expression that has two inputs that are expected to the be same type. If the two inputs have - * different types, the analyzer will find the tightest common type and do the proper type casting. + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. */ -abstract class BinaryOperator extends BinaryExpression { - self: Product => +trait AutoCastInputTypes { self: Expression => - def symbol: String + def inputTypes: Seq[DataType] - override def toString: String = s"($left $symbol $right)" -} - - -private[sql] object BinaryOperator { - def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index caf021b016a41..ebabb6f117851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"UDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index da520f56b430e..a9fc54c548f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -128,6 +128,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -161,6 +162,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -399,6 +401,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } + override def toString: String = s"AVG($child)" + override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -490,6 +494,8 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } + override def toString: String = s"SUM($child)" + override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4fbf4c87009c2..5363b3556886a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def prettyName: String = "positive" + override def toString: String = s"positive($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { + override def toString: String = s"Abs($child)" + override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -77,9 +79,10 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryOperator { +abstract class BinaryArithmetic extends BinaryExpression { self: Product => + override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -357,9 +360,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - - override def symbol: String = "max" - override def prettyName: String = symbol + override def toString: String = s"MaxOf($left, $right)" } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -412,6 +413,5 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def symbol: String = "min" - override def prettyName: String = symbol + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e7dc4ec8b14..5def57b067424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def prettyName: String = "array" + override def toString: String = s"Array(${children.mkString(",")})" } /** @@ -71,6 +71,4 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } - - override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 145d323a9f0bb..78be2824347d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,6 +38,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } + override def toString: String = s"Coalesce(${children.mkString(",")})" + override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 34df89a163895..a777f77add2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryOperator with Predicate { +abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 5d51a4ca65332..daa9f4403ffab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,6 +137,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType + override def symbol: String = "++=" + override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b020f2bbc5818..4cbfc4e084948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,6 +75,8 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { + override def symbol: String = "LIKE" + // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -99,16 +101,14 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() - - override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) - override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,7 +134,9 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase + override def convert(v: UTF8String): UTF8String = v.toUpperCase() + + override def toString: String = s"Upper($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -146,7 +148,9 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase + override def convert(v: UTF8String): UTF8String = v.toLowerCase() + + override def toString: String = s"Lower($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -174,6 +178,8 @@ trait StringComparison extends AutoCastInputTypes { } } + override def symbol: String = nodeName + override def toString: String = s"$nodeName($left, $right)" } @@ -278,6 +284,12 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } + + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. + case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" + case _ => s"SUBSTR($str, $pos, $len)" + } } /** @@ -292,9 +304,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } + override def toString: String = s"length($child)" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } - - override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 86792f0217572..bda217935cb05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryOperator => actual.append(b.symbol); b + case b: BinaryExpression => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryOperator => actual.append(b.symbol); b + case b: BinaryExpression => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryOperator => actual.append(b.symbol); + case b: BinaryExpression => actual.append(b.symbol); case l: Literal => actual.append(l.toString); } From 9fd13d5613b6d16a78d97d4798f085b56107d343 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 21:14:13 -0700 Subject: [PATCH 088/274] [SPARK-8770][SQL] Create BinaryOperator abstract class. Our current BinaryExpression abstract class is not for generic binary expressions, i.e. it requires left/right children to have the same type. However, due to its name, contributors build new binary expressions that don't have that assumption (e.g. Sha) and still extend BinaryExpression. This patch creates a new BinaryOperator abstract class, and update the analyzer o only apply type casting rule there. This patch also adds the notion of "prettyName" to expressions, which defines the user-facing name for the expression. Author: Reynold Xin Closes #7174 from rxin/binary-opterator and squashes the following commits: f31900d [Reynold Xin] [SPARK-8770][SQL] Create BinaryOperator abstract class. fceb216 [Reynold Xin] Merge branch 'master' of github.com:apache/spark into binary-opterator d8518cf [Reynold Xin] Updated Python tests. --- python/pyspark/sql/dataframe.py | 10 +- python/pyspark/sql/functions.py | 4 +- python/pyspark/sql/group.py | 24 +-- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../expressions/ExpectsInputTypes.scala | 59 +++++++ .../sql/catalyst/expressions/Expression.scala | 161 +++++++++--------- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 9 +- .../sql/catalyst/expressions/arithmetic.scala | 14 +- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 2 - .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 2 - .../expressions/stringOperations.scala | 26 +-- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- 15 files changed, 191 insertions(+), 155 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 273a40dd526cf..1e9c657cf81b3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -802,11 +802,11 @@ def groupBy(self, *cols): Each element should be a column name (string) or an expression (:class:`Column`). >>> df.groupBy().avg().collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ @@ -864,10 +864,10 @@ def agg(self, *exprs): (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() - [Row(MAX(age)=5)] + [Row(max(age)=5)] >>> from pyspark.sql import functions as F >>> df.agg(F.min(df.age)).collect() - [Row(MIN(age)=2)] + [Row(min(age)=2)] """ return self.groupBy().agg(*exprs) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4e2be88e9e3b9..f9a15d4a66309 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -266,7 +266,7 @@ def coalesce(*cols): >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() +-------------+ - |Coalesce(a,b)| + |coalesce(a,b)| +-------------+ | null| | 1| @@ -275,7 +275,7 @@ def coalesce(*cols): >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() +----+----+---------------+ - | a| b|Coalesce(a,0.0)| + | a| b|coalesce(a,0.0)| +----+----+---------------+ |null|null| 0.0| | 1|null| 1.0| diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 5a37a673ee80c..04594d5a836ce 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -75,11 +75,11 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] + [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] + [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -110,9 +110,9 @@ def mean(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().mean('age').collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df3.groupBy().mean('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] + [Row(avg(age)=3.5, avg(height)=82.5)] """ @df_varargs_api @@ -125,9 +125,9 @@ def avg(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().avg('age').collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df3.groupBy().avg('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] + [Row(avg(age)=3.5, avg(height)=82.5)] """ @df_varargs_api @@ -136,9 +136,9 @@ def max(self, *cols): """Computes the max value for each numeric columns for each group. >>> df.groupBy().max('age').collect() - [Row(MAX(age)=5)] + [Row(max(age)=5)] >>> df3.groupBy().max('age', 'height').collect() - [Row(MAX(age)=5, MAX(height)=85)] + [Row(max(age)=5, max(height)=85)] """ @df_varargs_api @@ -149,9 +149,9 @@ def min(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().min('age').collect() - [Row(MIN(age)=2)] + [Row(min(age)=2)] >>> df3.groupBy().min('age', 'height').collect() - [Row(MIN(age)=2, MIN(height)=80)] + [Row(min(age)=2, min(height)=80)] """ @df_varargs_api @@ -162,9 +162,9 @@ def sum(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().sum('age').collect() - [Row(SUM(age)=7)] + [Row(sum(age)=7)] >>> df3.groupBy().sum('age', 'height').collect() - [Row(SUM(age)=7, SUM(height)=165)] + [Row(sum(age)=7, sum(height)=165)] """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2ab5cb666fbcd..8420c54f7c335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,6 +150,7 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ + // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -159,19 +160,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + case b @ BinaryOperator(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + case b @ BinaryOperator(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryExpression(StringNaN, right @ FloatType()) => + case b @ BinaryOperator(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryExpression(left @ FloatType(), StringNaN) => + case b @ BinaryOperator(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryExpression(left @ StringNaN, StringNaN) => + case b @ BinaryOperator(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -245,12 +246,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryExpressions. + // Also widen types for BinaryOperator. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -478,7 +479,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala new file mode 100644 index 0000000000000..450fc4165f93b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.DataType + + +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait AutoCastInputTypes { self: Expression => + + def inputTypes: Seq[DataType] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e18a3118945e8..cafbbafdca207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,17 +119,6 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns a string representation of this expression that does not have developer centric - * debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -154,71 +143,40 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess -} - -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable - - override def toString: String = s"($left $symbol $right)" /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f accepts two variable names and returns Java code to compute the output. + * Returns a user-facing string representation of this expression's name. + * This should usually match the name of the function in SQL. */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def prettyName: String = getClass.getSimpleName.toLowerCase /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Returns a user-facing string representation of this expression, i.e. does not have developer + * centric debugging information like the expression id. */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } -} -private[sql] object BinaryExpression { - def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def toString: String = prettyName + children.mkString("(", ",", ")") } + +/** + * A leaf expression, i.e. one without any child expressions. + */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } + +/** + * An expression with one input and one output. The output is by default evaluated to null + * if the input is evaluated to null. + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -265,39 +223,76 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } + /** - * An trait that gets mixin to define the expected input types of an expression. + * An expression with two inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. */ -trait ExpectsInputTypes { self: Expression => +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + * @param f accepts two variable names and returns Java code to compute the output. */ - def inputTypes: Seq[Any] + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } } + /** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. + * An expression that has two inputs that are expected to the be same type. If the two inputs have + * different types, the analyzer will find the tightest common type and do the proper type casting. */ -trait AutoCastInputTypes { self: Expression => +abstract class BinaryOperator extends BinaryExpression { + self: Product => - def inputTypes: Seq[DataType] + def symbol: String - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } + override def toString: String = s"($left $symbol $right)" +} + + +private[sql] object BinaryOperator { + def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ebabb6f117851..caf021b016a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index a9fc54c548f49..64e07bd2a17db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -94,7 +94,6 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() @@ -128,7 +127,6 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -162,7 +160,6 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -390,6 +387,8 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def prettyName: String = "avg" + override def nullable: Boolean = true override def dataType: DataType = child.dataType match { @@ -401,8 +400,6 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -494,8 +491,6 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString: String = s"SUM($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5363b3556886a..4fbf4c87009c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def toString: String = s"positive($child)" + override def prettyName: String = "positive" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,8 +69,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { - override def toString: String = s"Abs($child)" - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -79,10 +77,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryExpression { +abstract class BinaryArithmetic extends BinaryOperator { self: Product => - override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -360,7 +357,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - override def toString: String = s"MaxOf($left, $right)" + + override def symbol: String = "max" + override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -413,5 +412,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def toString: String = s"MinOf($left, $right)" + override def symbol: String = "min" + override def prettyName: String = symbol } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5def57b067424..67e7dc4ec8b14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString: String = s"Array(${children.mkString(",")})" + override def prettyName: String = "array" } /** @@ -71,4 +71,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } + + override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 78be2824347d7..145d323a9f0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,8 +38,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a777f77add2db..34df89a163895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryExpression with Predicate { +abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index daa9f4403ffab..5d51a4ca65332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,8 +137,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType - override def symbol: String = "++=" - override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 4cbfc4e084948..b020f2bbc5818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,8 +75,6 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "LIKE" - // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -101,14 +99,16 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,9 +134,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase() - - override def toString: String = s"Upper($child)" + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -148,9 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase() - - override def toString: String = s"Lower($child)" + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -178,8 +174,6 @@ trait StringComparison extends AutoCastInputTypes { } } - override def symbol: String = nodeName - override def toString: String = s"$nodeName($left, $right)" } @@ -284,12 +278,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } - - override def toString: String = len match { - // TODO: This is broken because max is not an integer value. - case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" - case _ => s"SUBSTR($str, $pos, $len)" - } } /** @@ -304,9 +292,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } - override def toString: String = s"length($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } + + override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index bda217935cb05..86792f0217572 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryExpression => actual.append(b.symbol); + case b: BinaryOperator => actual.append(b.symbol); case l: Literal => actual.append(l.toString); } From 4e4f74b5e1267d1ada4a8f57b86aee0d9c17d90a Mon Sep 17 00:00:00 2001 From: Rosstin Date: Wed, 1 Jul 2015 21:42:06 -0700 Subject: [PATCH 089/274] [SPARK-8660] [MLLIB] removed > symbols from comments in LogisticRegressionSuite.scala for ease of copypaste '>' symbols removed from comments in LogisticRegressionSuite.scala, for ease of copypaste also single-lined the multiline commands (is this desirable, or does it violate style?) Author: Rosstin Closes #7167 from Rosstin/SPARK-8660-2 and squashes the following commits: f4b9bc8 [Rosstin] SPARK-8660 restored character limit on multiline comments in LogisticRegressionSuite.scala fe6b112 [Rosstin] SPARK-8660 > symbols removed from LogisticRegressionSuite.scala for easy of copypaste 39ddd50 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8661 5a05dee [Rosstin] SPARK-8661 for LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments to make it easier to copy-paste the R code. bb9a4b1 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8660 242aedd [Rosstin] SPARK-8660, changed comment style from JavaDoc style to normal multiline comment in order to make copypaste into R easier, in file classification/LogisticRegressionSuite.scala 2cd2985 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 21ac1e5 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- .../LogisticRegressionSuite.scala | 117 ++++++++++-------- 1 file changed, 63 insertions(+), 54 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bc6eeac1db5da..ba8fbee84197c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -214,12 +214,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - > weights + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 2.8366423 @@ -245,13 +246,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - > weights + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -278,12 +280,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - > weights + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) -0.05627428 @@ -310,13 +313,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, intercept=FALSE)) - > weights + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -343,12 +347,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - > weights + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 0.15021751 @@ -375,13 +380,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, intercept=FALSE)) - > weights + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -408,12 +414,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - > weights + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 0.57734851 @@ -440,13 +447,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, intercept=FALSE)) - > weights + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -503,12 +511,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { /* Using the following R code to load the data and train the model using glmnet package. - > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - > weights + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + weights + 5 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) -0.2480643 From b285ac5ba85fe0b32b00726ad7d3a2efb602e885 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 1 Jul 2015 22:19:51 -0700 Subject: [PATCH 090/274] [SPARK-8227] [SQL] Add function unhex cc chenghao-intel adrian-wang Author: zhichao.li Closes #7113 from zhichao-li/unhex and squashes the following commits: 379356e [zhichao.li] remove exception checking a4ae6dc [zhichao.li] add udf_unhex to whitelist fe5c14a [zhichao.li] add todigit 607d7a3 [zhichao.li] use checkInputTypes bffd37f [zhichao.li] change to use Hex in apache common package cde73f5 [zhichao.li] update to use AutoCastInputTypes 11945c7 [zhichao.li] style c852d46 [zhichao.li] Add function unhex --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 52 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 6 +++ .../org/apache/spark/sql/functions.scala | 18 +++++++ .../spark/sql/MathExpressionsSuite.scala | 10 ++++ .../execution/HiveCompatibilitySuite.scala | 1 + 6 files changed, 88 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d53eaedda56b0..6f04298d4711b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -157,6 +157,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[Upper]("ucase"), + expression[UnHex]("unhex"), expression[Upper]("upper") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index b51318dd5044c..8633eb06ffee4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -351,6 +351,58 @@ case class Pow(left: Expression, right: Expression) } } +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. + */ +case class UnHex(child: Expression) extends UnaryExpression with Serializable { + + override def dataType: DataType = BinaryType + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}") + } + } + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].getBytes) + } + } + + private val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } + + private def unhex(inputBytes: Array[Byte]): Array[Byte] = { + var bytes = inputBytes + if ((bytes.length & 0x01) != 0) { + bytes = '0'.toByte +: bytes + } + val out = new Array[Byte](bytes.length >> 1) + // two characters form the hex value. + var i = 0 + while (i < bytes.length) { + val first = unhexDigits(bytes(i)) + val second = unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { return null} + out(i / 2) = (((first << 4) | second) & 0xFF).toByte + i += 2 + } + out + } +} + case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index b932d4ab850c7..b3345d7069159 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -238,6 +238,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("unhex") { + checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) + checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) + } + test("hypot") { testBinary(Hypot, math.hypot) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4e8f3f96bf4db..e6f623bdf39eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1053,6 +1053,24 @@ object functions { */ def hex(colName: String): Column = hex(Column(colName)) + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(column: Column): Column = UnHex(column.expr) + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(colName: String): Column = unhex(Column(colName)) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index d6331aa4ff09e..c03cde38d75d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -225,6 +225,16 @@ class MathExpressionsSuite extends QueryTest { checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) } + test("unhex") { + val data = Seq(("1C", "737472696E67")).toDF("a", "b") + checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) + checkAnswer(data.select(unhex('b)), Row("string".getBytes)) + checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) + checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) + checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) + checkAnswer(data.selectExpr("""unhex("G123")"""), Row(null)) + } + test("hypot") { testTwoToOneMathFunction(hypot, hypot, math.hypot) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f88e62763ca70..415a81644c58f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -949,6 +949,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_trim", "udf_ucase", "udf_unix_timestamp", + "udf_unhex", "udf_upper", "udf_var_pop", "udf_var_samp", From 792fcd802c99a0aef2b67d54f0e6e58710e65956 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Wed, 1 Jul 2015 22:59:04 -0700 Subject: [PATCH 091/274] [SPARK-8754] [YARN] YarnClientSchedulerBackend doesn't stop gracefully in failure conditions In YarnClientSchedulerBackend.stop(), added a check for monitorThread. Author: Devaraj K Closes #7153 from devaraj-kavali/master and squashes the following commits: 66be9ad [Devaraj K] https://issues.apache.org/jira/browse/SPARK-8754 YarnClientSchedulerBackend doesn't stop gracefully in failure conditions --- .../spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 1c8d7ec57635f..dd8c4fdb549ed 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -148,7 +148,9 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - monitorThread.interrupt() + if (monitorThread != null) { + monitorThread.interrupt() + } super.stop() client.stop() logInfo("Stopped") From 646366b5d2f12e42f8e7287672ba29a8c918a17d Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Wed, 1 Jul 2015 23:01:44 -0700 Subject: [PATCH 092/274] [SPARK-8688] [YARN] Bug fix: disable the cache fs to gain the HDFS connection. If `fs.hdfs.impl.disable.cache` was `false`(default), `FileSystem` will use the cached `DFSClient` which use old token. [AMDelegationTokenRenewer](https://github.com/apache/spark/blob/master/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala#L196) ```scala val credentials = UserGroupInformation.getCurrentUser.getCredentials credentials.writeTokenStorageFile(tempTokenPath, discachedConfiguration) ``` Although the `credentials` had the new Token, but it still use the cached client and old token. So It's better to set the `fs.hdfs.impl.disable.cache` as `true` to avoid token expired. [Jira](https://issues.apache.org/jira/browse/SPARK-8688) Author: huangzhaowei Closes #7069 from SaintBacchus/SPARK-8688 and squashes the following commits: f94cd0b [huangzhaowei] modify function parameter 8fb9eb9 [huangzhaowei] explicit the comment 0cd55c9 [huangzhaowei] Rename function name to be an accurate one cf776a1 [huangzhaowei] [SPARK-8688][YARN]Bug fix: disable the cache fs to gain the HDFS connection. --- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 13 +++++++++++++ .../deploy/yarn/AMDelegationTokenRenewer.scala | 10 ++++++---- .../yarn/ExecutorDelegationTokenUpdater.scala | 5 ++++- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 7fa75ac8c2b54..6d14590a1d192 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -334,6 +334,19 @@ class SparkHadoopUtil extends Logging { * Stop the thread that does the delegation token updates. */ private[spark] def stopExecutorDelegationTokenRenewer() {} + + /** + * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. + * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. + */ + private[spark] def getConfBypassingFSCache( + hadoopConf: Configuration, + scheme: String): Configuration = { + val newConf = new Configuration(hadoopConf) + val confKey = s"fs.${scheme}.impl.disable.cache" + newConf.setBoolean(confKey, true) + newConf + } } object SparkHadoopUtil { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index 77af46c192cc2..56e4741b93873 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -65,6 +65,8 @@ private[yarn] class AMDelegationTokenRenewer( sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val freshHadoopConf = + hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -123,7 +125,7 @@ private[yarn] class AMDelegationTokenRenewer( private def cleanupOldFiles(): Unit = { import scala.concurrent.duration._ try { - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( @@ -169,13 +171,13 @@ private[yarn] class AMDelegationTokenRenewer( // Get a copy of the credentials override def run(): Void = { val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst - hadoopUtil.obtainTokensForNamenodes(nns, hadoopConf, tempCreds) + hadoopUtil.obtainTokensForNamenodes(nns, freshHadoopConf, tempCreds) null } }) // Add the temp credentials back to the original ones. UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file // and update the lastCredentialsFileSuffix. @@ -194,7 +196,7 @@ private[yarn] class AMDelegationTokenRenewer( val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, hadoopConf) + credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") remoteFs.rename(tempTokenPath, tokenPath) logInfo("Delegation token file rename complete.") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala index 229c2c4d5eb36..94feb6393fd69 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -35,6 +35,9 @@ private[spark] class ExecutorDelegationTokenUpdater( @volatile private var lastCredentialsFileSuffix = 0 private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val freshHadoopConf = + SparkHadoopUtil.get.getConfBypassingFSCache( + hadoopConf, new Path(credentialsFile).toUri.getScheme) private val delegationTokenRenewer = Executors.newSingleThreadScheduledExecutor( @@ -49,7 +52,7 @@ private[spark] class ExecutorDelegationTokenUpdater( def updateCredentialsIfRequired(): Unit = { try { val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) SparkHadoopUtil.get.listFilesSorted( remoteFs, credentialsFilePath.getParent, credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) From d14338eafc5d633f766bd52ba610fd7c4fe90581 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 1 Jul 2015 23:04:05 -0700 Subject: [PATCH 093/274] [SPARK-8771] [TRIVIAL] Add a version to the deprecated annotation for the actorSystem Author: Holden Karau Closes #7172 from holdenk/SPARK-8771-actor-system-deprecation-tag-uses-deprecated-deprecation-tag and squashes the following commits: 7f1455b [Holden Karau] Add .0s to the versions for the derpecated anotations in SparkEnv.scala ca13c9d [Holden Karau] Add a version to the deprecated annotation for the actorSystem in SparkEnv --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 1b133fbdfaf59..d18fc599e9890 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -76,7 +76,7 @@ class SparkEnv ( val conf: SparkConf) extends Logging { // TODO Remove actorSystem - @deprecated("Actor system is no longer supported as of 1.4") + @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false @@ -173,7 +173,7 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. */ - @deprecated("Use SparkEnv.get instead", "1.2") + @deprecated("Use SparkEnv.get instead", "1.2.0") def getThreadLocal: SparkEnv = { env } From 15d41cc501f5fa7ac82c4a6741e416bb557f610a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 1 Jul 2015 23:05:45 -0700 Subject: [PATCH 094/274] [SPARK-8769] [TRIVIAL] [DOCS] toLocalIterator should mention it results in many jobs Author: Holden Karau Closes #7171 from holdenk/SPARK-8769-toLocalIterator-documentation-improvement and squashes the following commits: 97ddd99 [Holden Karau] Add note --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 10610f4b6f1ff..cac6e3b477e16 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -890,6 +890,10 @@ abstract class RDD[T: ClassTag]( * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. + * + * Note: this results in multiple Spark jobs, and if the input RDD is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input RDD should be cached first. */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { From 377ff4c9e8942882183d94698684824e9dc9f391 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 23:06:52 -0700 Subject: [PATCH 095/274] [SPARK-8740] [PROJECT INFRA] Support GitHub OAuth tokens in dev/merge_spark_pr.py This commit allows `dev/merge_spark_pr.py` to use personal GitHub OAuth tokens in order to make authenticated requests. This is necessary to work around per-IP rate limiting issues. To use a token, just set the `GITHUB_OAUTH_KEY` environment variable. You can create a personal token at https://github.com/settings/tokens; we only require `public_repo` scope. If the script fails due to a rate-limit issue, it now logs a useful message directing the user to the OAuth token instructions. Author: Josh Rosen Closes #7136 from JoshRosen/pr-merge-script-oauth-authentication and squashes the following commits: 4d011bd [Josh Rosen] Fix error message 23d92ff [Josh Rosen] Support GitHub OAuth tokens in dev/merge_spark_pr.py --- dev/merge_spark_pr.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index cf827ce89b857..4a17d48d8171d 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -47,6 +47,12 @@ JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") +# OAuth key used for issuing requests against the GitHub API. If this is not defined, then requests +# will be unauthenticated. You should only need to configure this if you find yourself regularly +# exceeding your IP's unauthenticated request rate limit. You can create an OAuth key at +# https://github.com/settings/tokens. This script only requires the "public_repo" scope. +GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY") + GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" @@ -58,9 +64,17 @@ def get_json(url): try: - return json.load(urllib2.urlopen(url)) + request = urllib2.Request(url) + if GITHUB_OAUTH_KEY: + request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY) + return json.load(urllib2.urlopen(request)) except urllib2.HTTPError as e: - print "Unable to fetch URL, exiting: %s" % url + if "X-RateLimit-Remaining" in e.headers and e.headers["X-RateLimit-Remaining"] == '0': + print "Exceeded the GitHub API rate limit; see the instructions in " + \ + "dev/merge_spark_pr.py to configure an OAuth token for making authenticated " + \ + "GitHub requests." + else: + print "Unable to fetch URL, exiting: %s" % url sys.exit(-1) From 3697232b7d438979cc119b2a364296b0eec4a16a Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Wed, 1 Jul 2015 23:11:02 -0700 Subject: [PATCH 096/274] [SPARK-3071] Increase default driver memory I've updated default values in comments, documentation, and in the command line builder to be 1g based on comments in the JIRA. I've also updated most usages to point at a single variable defined in the Utils.scala and JavaUtils.java files. This wasn't possible in all cases (R, shell scripts etc.) but usage in most code is now pointing at the same place. Please let me know if I've missed anything. Will the spark-shell use the value within the command line builder during instantiation? Author: Ilya Ganelin Closes #7132 from ilganeli/SPARK-3071 and squashes the following commits: 4074164 [Ilya Ganelin] String fix 271610b [Ilya Ganelin] Merge branch 'SPARK-3071' of github.com:ilganeli/spark into SPARK-3071 273b6e9 [Ilya Ganelin] Test fix fd67721 [Ilya Ganelin] Update JavaUtils.java 26cc177 [Ilya Ganelin] test fix e5db35d [Ilya Ganelin] Fixed test failure 39732a1 [Ilya Ganelin] merge fix a6f7deb [Ilya Ganelin] Created default value for DRIVER MEM in Utils that's now used in almost all locations instead of setting manually in each 09ad698 [Ilya Ganelin] Update SubmitRestProtocolSuite.scala 19b6f25 [Ilya Ganelin] Missed one doc update 2698a3d [Ilya Ganelin] Updated default value for driver memory --- R/pkg/R/sparkR.R | 2 +- conf/spark-env.sh.template | 2 +- .../org/apache/spark/deploy/ClientArguments.scala | 2 +- .../org/apache/spark/deploy/SparkSubmitArguments.scala | 5 +++-- .../spark/deploy/rest/mesos/MesosRestServer.scala | 2 +- .../apache/spark/deploy/worker/WorkerArguments.scala | 2 +- core/src/main/scala/org/apache/spark/util/Utils.scala | 6 ++++++ .../spark/deploy/rest/SubmitRestProtocolSuite.scala | 10 +++++----- docs/configuration.md | 4 ++-- .../org/apache/spark/launcher/CommandBuilderUtils.java | 2 +- .../spark/launcher/SparkSubmitCommandBuilder.java | 2 +- .../spark/mllib/tree/model/DecisionTreeModel.scala | 2 +- .../spark/mllib/tree/model/treeEnsembleModels.scala | 2 +- .../java/org/apache/spark/network/util/JavaUtils.java | 6 ++++++ .../org/apache/spark/deploy/yarn/ClientArguments.scala | 7 ++++--- 15 files changed, 35 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 633b869f91784..86233e01db365 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -109,7 +109,7 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "512m") + sparkMem <- Sys.getenv("SPARK_MEM", "1024m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 43c4288912b18..192d3ae091134 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -22,7 +22,7 @@ # - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 316e2d59f01b8..42d3296062e6d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -117,7 +117,7 @@ private[deploy] class ClientArguments(args: Array[String]) { private[deploy] object ClientArguments { val DEFAULT_CORES = 1 - val DEFAULT_MEMORY = 512 // MB + val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB val DEFAULT_SUPERVISE = false def isValidJarUrl(s: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b7429a901e162..73ab18332feb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -461,8 +461,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) outStream.println(command) + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB outStream.println( - """ + s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -488,7 +489,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M). + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: ${mem_mb}M). | --driver-java-options Extra Java options to pass to the driver. | --driver-library-path Extra library path entries to pass to the driver. | --driver-class-path Extra class path entries to pass to the driver. Note that diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 8198296eeb341..868cc35d06ef3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -59,7 +59,7 @@ private[mesos] class MesosSubmitRequestServlet( extends SubmitRequestServlet { private val DEFAULT_SUPERVISE = false - private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // mb private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 9678631da9f6f..1d2ecab517613 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -164,7 +164,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } } // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) + math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) } def checkWorkerMemory(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a7fc749a2b0c6..944560a91354a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -80,6 +80,12 @@ private[spark] object Utils extends Logging { */ val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + /** + * Define a default value for driver memory here since this value is referenced across the code + * base and nearly all files already use Utils.scala + */ + val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 115ac0534a1b4..725b8848bc052 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.deploy.rest import java.lang.Boolean -import java.lang.Integer import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils /** * Tests for the REST application submission protocol. @@ -93,7 +93,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { // optional fields conf.set("spark.jars", "mayonnaise.jar,ketchup.jar") conf.set("spark.files", "fireball.png") - conf.set("spark.driver.memory", "512m") + conf.set("spark.driver.memory", s"${Utils.DEFAULT_DRIVER_MEM_MB}m") conf.set("spark.driver.cores", "180") conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red") conf.set("spark.driver.extraClassPath", "food-coloring.jar") @@ -126,7 +126,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { assert(newMessage.sparkProperties("spark.app.name") === "SparkPie") assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar") assert(newMessage.sparkProperties("spark.files") === "fireball.png") - assert(newMessage.sparkProperties("spark.driver.memory") === "512m") + assert(newMessage.sparkProperties("spark.driver.memory") === s"${Utils.DEFAULT_DRIVER_MEM_MB}m") assert(newMessage.sparkProperties("spark.driver.cores") === "180") assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") @@ -230,7 +230,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { """.stripMargin private val submitDriverRequestJson = - """ + s""" |{ | "action" : "CreateSubmissionRequest", | "appArgs" : [ "two slices", "a hint of cinnamon" ], @@ -246,7 +246,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { | "spark.driver.supervise" : "false", | "spark.app.name" : "SparkPie", | "spark.cores.max" : "10000", - | "spark.driver.memory" : "512m", + | "spark.driver.memory" : "${Utils.DEFAULT_DRIVER_MEM_MB}m", | "spark.files" : "fireball.png", | "spark.driver.cores" : "180", | "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", diff --git a/docs/configuration.md b/docs/configuration.md index affcd21514d88..bebaf6f62e90a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -137,10 +137,10 @@ of the most common options to set are: spark.driver.memory - 512m + 1g Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 512m, 2g). + (e.g. 1g, 2g).
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 2665a700fe1f5..a16c0d2b5ca0b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -27,7 +27,7 @@ */ class CommandBuilderUtils { - static final String DEFAULT_MEM = "512m"; + static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 3e5a2820b6c11..87c43aa9980e1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -208,7 +208,7 @@ private List buildSparkSubmitCommand(Map env) throws IOE // - properties file. // - SPARK_DRIVER_MEMORY env variable // - SPARK_MEM env variable - // - default value (512m) + // - default value (1g) // Take Thrift Server as daemon String tsMemory = isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 25bb1453db404..f2c78bbabff0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -198,7 +198,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val driverMemory = sc.getConf.getOption("spark.driver.memory") .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB) if (driverMemory <= memThreshold) { logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + s" driver memory (${driverMemory}m)." + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 1e3333d8d81d0..905c5fb42bd44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -387,7 +387,7 @@ private[tree] object TreeEnsembleModel extends Logging { val driverMemory = sc.getConf.getOption("spark.driver.memory") .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB) if (driverMemory <= memThreshold) { logWarning(s"$className.save() was called, but it may fail because of too little" + s" driver memory (${driverMemory}m)." + diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 6b514aaa1290d..7d27439cfde7a 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -39,6 +39,12 @@ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + /** + * Define a default value for driver memory here since this value is referenced across the code + * base and nearly all files already use Utils.scala + */ + public static final long DEFAULT_DRIVER_MEM_MB = 1024; + /** Closes the given object, ignoring IOExceptions. */ public static void closeQuietly(Closeable closeable) { try { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 35e990602a6cf..19d1bbff9993f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -46,7 +46,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var keytab: String = null def isClusterMode: Boolean = userClass != null - private var driverMemory: Int = 512 // MB + private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB private var driverCores: Int = 1 private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead" private val amMemKey = "spark.yarn.am.memory" @@ -262,8 +262,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB message + - """ + s""" |Usage: org.apache.spark.deploy.yarn.Client [options] |Options: | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster @@ -275,7 +276,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores per executor (Default: 1). - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb) | --driver-cores NUM Number of cores used by the driver (Default: 1). | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) | --name NAME The name of your application (Default: Spark) From 1b0c8e61040bf06213f9758f775679dcc41b0cce Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Wed, 1 Jul 2015 23:14:13 -0700 Subject: [PATCH 097/274] [SPARK-8687] [YARN] Fix bug: Executor can't fetch the new set configuration in yarn-client Spark initi the properties CoarseGrainedSchedulerBackend.start ```scala // TODO (prashant) send conf instead of properties driverEndpoint = rpcEnv.setupEndpoint( CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) ``` Then the yarn logic will set some configuration but not update in this `properties`. So `Executor` won't gain the `properties`. [Jira](https://issues.apache.org/jira/browse/SPARK-8687) Author: huangzhaowei Closes #7066 from SaintBacchus/SPARK-8687 and squashes the following commits: 1de4f48 [huangzhaowei] Ensure all necessary properties have already been set before startup ExecutorLaucher --- .../scheduler/cluster/YarnClientSchedulerBackend.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index dd8c4fdb549ed..3a0b9443d2d7b 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -41,7 +41,6 @@ private[spark] class YarnClientSchedulerBackend( * This waits until the application is running. */ override def start() { - super.start() val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort @@ -56,6 +55,12 @@ private[spark] class YarnClientSchedulerBackend( totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.submitApplication() + + // SPARK-8687: Ensure all necessary properties have already been set before + // we initialize our driver scheduler backend, which serves these properties + // to the executors + super.start() + waitForApplication() monitorThread = asyncMonitorApplication() monitorThread.start() From 41588365ad29408ccabd216b411e9c43f0053151 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 2 Jul 2015 21:16:35 +0900 Subject: [PATCH 098/274] [DOCS] Fix minor wrong lambda expression example. It's a really minor issue but there is an example with wrong lambda-expression usage in `SQLContext.scala` like as follows. ``` sqlContext.udf().register("myUDF", (Integer arg1, String arg2) -> arg2 + arg1), <- We have an extra `)` here. DataTypes.StringType); ``` Author: Kousuke Saruta Closes #7187 from sarutak/fix-minor-wrong-lambda-expression and squashes the following commits: a13196d [Kousuke Saruta] Fixed minor wrong lambda expression example. --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index fc14a77538ef1..e81371e7b0e83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -274,7 +274,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Or, to use Java 8 lambda syntax: * {{{ * sqlContext.udf().register("myUDF", - * (Integer arg1, String arg2) -> arg2 + arg1), + * (Integer arg1, String arg2) -> arg2 + arg1, * DataTypes.StringType); * }}} * From c572e25617f993c6b2e7d5f15f0fbf4426f89fab Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Thu, 2 Jul 2015 13:42:48 +0100 Subject: [PATCH 099/274] [SPARK-8787] [SQL] Changed parameter order of @deprecated in package object sql Parameter order of deprecated annotation in package object sql is wrong >>deprecated("1.3.0", "use DataFrame") . This has to be changed to deprecated("use DataFrame", "1.3.0") Author: Vinod K C Closes #7183 from vinodkc/fix_deprecated_param_order and squashes the following commits: 1cbdbe8 [Vinod K C] Modified the message 700911c [Vinod K C] Changed order of parameters --- sql/core/src/main/scala/org/apache/spark/sql/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 4e94fd07a8771..a9c600b139b18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -46,6 +46,6 @@ package object sql { * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. * @deprecated As of 1.3.0, replaced by `DataFrame`. */ - @deprecated("1.3.0", "use DataFrame") + @deprecated("use DataFrame", "1.3.0") type SchemaRDD = DataFrame } From 1bbdf9ead9e912f60dccbb23029b7de4948ebee3 Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Thu, 2 Jul 2015 13:45:19 +0100 Subject: [PATCH 100/274] [SPARK-8746] [SQL] update download link for Hive 0.13.1 updated the [Hive 0.13.1](https://archive.apache.org/dist/hive/hive-0.13.1) download link in `sql/README.md` Author: Christian Kadner Closes #7144 from ckadner/SPARK-8746 and squashes the following commits: 65d80f7 [Christian Kadner] [SPARK-8746][SQL] update download link for Hive 0.13.1 --- sql/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/README.md b/sql/README.md index 46aec7cef7984..63d4dac9829e0 100644 --- a/sql/README.md +++ b/sql/README.md @@ -25,7 +25,7 @@ export HADOOP_HOME="/hadoop-1.0.4" If you are working with Hive 0.13.1, the following steps are needed: -1. Download Hive's [0.13.1](https://hive.apache.org/downloads.html) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). +1. Download Hive's [0.13.1](https://archive.apache.org/dist/hive/hive-0.13.1) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). 2. Set `HADOOP_HOME` with `export HADOOP_HOME=""` 3. Download all Hive 0.13.1a jars (Hive jars actually used by Spark) from [here](http://mvnrepository.com/artifact/org.spark-project.hive) and replace corresponding original 0.13.1 jars in `$HIVE_HOME/lib`. 4. Download [Kryo 2.21 jar](http://mvnrepository.com/artifact/com.esotericsoftware.kryo/kryo/2.21) (Note: 2.22 jar does not work) and [Javolution 5.5.1 jar](http://mvnrepository.com/artifact/javolution/javolution/5.5.1) to `$HIVE_HOME/lib`. From 246265f2bb056d5e9011d3331b809471a24ff8d7 Mon Sep 17 00:00:00 2001 From: Wisely Chen Date: Thu, 2 Jul 2015 09:58:12 -0700 Subject: [PATCH 101/274] [SPARK-8690] [SQL] Add a setting to disable SparkSQL parquet schema merge by using datasource API The detail problem story is in https://issues.apache.org/jira/browse/SPARK-8690 General speaking, I add a config spark.sql.parquet.mergeSchema to achieve the sqlContext.load("parquet" , Map( "path" -> "..." , "mergeSchema" -> "false" )) It will become a simple flag and without any side affect. Author: Wisely Chen Closes #7070 from thegiive/SPARK8690 and squashes the following commits: c6f3e86 [Wisely Chen] Refactor some code style and merge the test case to ParquetSchemaMergeConfigSuite 94c9307 [Wisely Chen] Remove some style problem db8ef1b [Wisely Chen] Change config to SQLConf and add test case b6806fb [Wisely Chen] remove text c0edb8c [Wisely Chen] [SPARK-8690] add a config spark.sql.parquet.mergeSchema to disable datasource API schema merge feature. --- .../scala/org/apache/spark/sql/SQLConf.scala | 6 ++++++ .../apache/spark/sql/parquet/newParquet.scala | 5 ++++- .../spark/sql/parquet/ParquetQuerySuite.scala | 20 +++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9a10a23937fbb..2c258b6ee399c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -227,6 +227,12 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") + val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema", + defaultValue = Some(true), + doc = "When true, the Parquet data source merges schemas collected from all data files, " + + "otherwise the schema is picked from the summary file or a random data file " + + "if no summary file is available.") + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", defaultValue = Some(false), doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index bc39fae2bcfde..5ac3e9a44e6fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -114,7 +114,10 @@ private[sql] class ParquetRelation2( // Should we merge schemas from all Parquet part-files? private val shouldMergeSchemas = - parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean + parameters + .get(ParquetRelation2.MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) private val maybeMetastoreSchema = parameters .get(ParquetRelation2.METASTORE_SCHEMA) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index fafad67dde3a7..a0a81c4309c0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.types._ @@ -122,6 +123,25 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { checkAnswer(df2, df.collect().toSeq) } } + + test("Enabling/disabling schema merging") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + testSchemaMerging(3) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + testSchemaMerging(2) + } + } } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { From 99c40cd0d8465525cac34dfa373b81532ef3d719 Mon Sep 17 00:00:00 2001 From: Alok Singh Date: Thu, 2 Jul 2015 09:58:57 -0700 Subject: [PATCH 102/274] [SPARK-8647] [MLLIB] Potential issue with constant hashCode I added the code, // see [SPARK-8647], this achieves the needed constant hash code without constant no. override def hashCode(): Int = this.getClass.getName.hashCode() does getting the constant hash code as per jira Author: Alok Singh Closes #7146 from aloknsingh/aloknsingh_SPARK-8647 and squashes the following commits: e58bccf [Alok Singh] [SPARK-8647][MLlib] to avoid the class derivation issues, change the constant hashCode to override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode() 43cdb89 [Alok Singh] [SPARK-8647][MLlib] Potential issue with constant hashCode --- .../main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 3 ++- .../src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 85e63b1382b5e..0a615494bb2d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -193,7 +193,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { } } - override def hashCode(): Int = 1994 + // see [SPARK-8647], this achieves the needed constant hash code without constant no. + override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode() override def typeName: String = "matrix" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 2ffa497a99d93..c9c27425d2877 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -234,7 +234,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } } - override def hashCode: Int = 7919 + // see [SPARK-8647], this achieves the needed constant hash code without constant no. + override def hashCode(): Int = classOf[VectorUDT].getName.hashCode() override def typeName: String = "vector" From 0a468a46bf5b905e9b0205e98b862570b2ac556e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 2 Jul 2015 09:59:54 -0700 Subject: [PATCH 103/274] [SPARK-8758] [MLLIB] Add Python user guide for PowerIterationClustering Add Python user guide for PowerIterationClustering Author: Yanbo Liang Closes #7155 from yanboliang/spark-8758 and squashes the following commits: 18d803b [Yanbo Liang] address comments dd29577 [Yanbo Liang] Add Python user guide for PowerIterationClustering --- data/mllib/pic_data.txt | 19 ++++++++++++++ docs/mllib-clustering.md | 54 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 data/mllib/pic_data.txt diff --git a/data/mllib/pic_data.txt b/data/mllib/pic_data.txt new file mode 100644 index 0000000000000..fcfef8cd19131 --- /dev/null +++ b/data/mllib/pic_data.txt @@ -0,0 +1,19 @@ +0 1 1.0 +0 2 1.0 +0 3 1.0 +1 2 1.0 +1 3 1.0 +2 3 1.0 +3 4 0.1 +4 5 1.0 +4 15 1.0 +5 6 1.0 +6 7 1.0 +7 8 1.0 +8 9 1.0 +9 10 1.0 +10 11 1.0 +11 12 1.0 +12 13 1.0 +13 14 1.0 +14 15 1.0 diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index dcaa3784be874..3aad4149f99db 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -327,11 +327,17 @@ which contains the computed clustering assignments. import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel} import org.apache.spark.mllib.linalg.Vectors -val similarities: RDD[(Long, Long, Double)] = ... +// Load and parse the data +val data = sc.textFile("data/mllib/pic_data.txt") +val similarities = data.map { line => + val parts = line.split(' ') + (parts(0).toLong, parts(1).toLong, parts(2).toDouble) +} +// Cluster the data into two classes using PowerIterationClustering val pic = new PowerIterationClustering() - .setK(3) - .setMaxIterations(20) + .setK(2) + .setMaxIterations(10) val model = pic.run(similarities) model.assignments.foreach { a => @@ -363,11 +369,22 @@ import scala.Tuple2; import scala.Tuple3; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.PowerIterationClustering; import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; -JavaRDD> similarities = ... +// Load and parse the data +JavaRDD data = sc.textFile("data/mllib/pic_data.txt"); +JavaRDD> similarities = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(" "); + return new Tuple3<>(new Long(parts[0]), new Long(parts[1]), new Double(parts[2])); + } + } +); +// Cluster the data into two classes using PowerIterationClustering PowerIterationClustering pic = new PowerIterationClustering() .setK(2) .setMaxIterations(10); @@ -383,6 +400,35 @@ PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc. {% endhighlight %} +
+ +[`PowerIterationClustering`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering) +implements the PIC algorithm. +It takes an `RDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering), +which contains the computed clustering assignments. + +{% highlight python %} +from __future__ import print_function +from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel + +# Load and parse the data +data = sc.textFile("data/mllib/pic_data.txt") +similarities = data.map(lambda line: tuple([float(x) for x in line.split(' ')])) + +# Cluster the data into two classes using PowerIterationClustering +model = PowerIterationClustering.train(similarities, 2, 10) + +model.assignments().foreach(lambda x: print(str(x.id) + " -> " + str(x.cluster))) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") +{% endhighlight %} +
+ ## Latent Dirichlet allocation (LDA) From 5b3338130dfd9db92c4894a348839a62ebb57ef3 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 2 Jul 2015 10:02:19 -0700 Subject: [PATCH 104/274] [SPARK-8223] [SPARK-8224] [SQL] shift left and shift right Jira: https://issues.apache.org/jira/browse/SPARK-8223 https://issues.apache.org/jira/browse/SPARK-8224 ~~I am aware of #7174 and will update this pr, if it's merged.~~ Done I don't know if #7034 can simplify this, but we can have a look on it, if it gets merged rxin In the Jira ticket the function as no second argument. I added a `numBits` argument that allows to specify the number of bits. I guess this improves the usability. I wanted to add `shiftleft(value)` as well, but the `selectExpr` dataframe tests crashes, if I have both. I order to do this, I added the following to the functions.scala `def shiftRight(e: Column): Column = ShiftRight(e.expr, lit(1).expr)`, but as I mentioned this doesn't pass tests like `df.selectExpr("shiftRight(a)", ...` (not enough arguments exception). If we need the bitwise shift in order to be hive compatible, I suggest to add `shiftLeft` and something like `shiftLeftX` Author: Tarek Auel Closes #7178 from tarekauel/8223 and squashes the following commits: 8023bb5 [Tarek Auel] [SPARK-8223][SPARK-8224] fixed test f3f64e6 [Tarek Auel] [SPARK-8223][SPARK-8224] Integer -> Int f628706 [Tarek Auel] [SPARK-8223][SPARK-8224] removed toString; updated function description 3b56f2a [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223 5189690 [Tarek Auel] [SPARK-8223][SPARK-8224] minor fix and style fix 9434a28 [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223 44ee324 [Tarek Auel] [SPARK-8223][SPARK-8224] docu fix ac7fe9d [Tarek Auel] [SPARK-8223][SPARK-8224] right and left bit shift --- python/pyspark/sql/functions.py | 24 +++++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../spark/sql/catalyst/expressions/math.scala | 98 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 28 +++++- .../org/apache/spark/sql/functions.scala | 38 +++++++ .../spark/sql/MathExpressionsSuite.scala | 34 +++++++ 6 files changed, 223 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9a15d4a66309..bccde6083ca3c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -412,6 +412,30 @@ def sha2(col, numBits): return Column(jc) +@since(1.5) +def shiftLeft(col, numBits): + """Shift the the given value numBits left. + + >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() + [Row(r=42)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.5) +def shiftRight(col, numBits): + """Shift the the given value numBits right. + + >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() + [Row(r=21)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6f04298d4711b..aa051b163363a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -125,6 +125,8 @@ object FunctionRegistry { expression[Pow]("power"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[ShiftLeft]("shiftleft"), + expression[ShiftRight]("shiftright"), expression[Signum]("sign"), expression[Signum]("signum"), expression[Sin]("sin"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 8633eb06ffee4..7504c6a066657 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -351,6 +351,104 @@ case class Pow(left: Expression, right: Expression) } } +case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression { + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess + case (_, IntegerType) => left.dataType match { + case LongType | IntegerType | ShortType | ByteType => + return TypeCheckResult.TypeCheckSuccess + case _ => // failed + } + case _ => // failed + } + TypeCheckResult.TypeCheckFailure( + s"ShiftLeft expects long, integer, short or byte value as first argument and an " + + s"integer value as second argument, not (${left.dataType}, ${right.dataType})") + } + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: Long => l << valueRight.asInstanceOf[Integer] + case i: Integer => i << valueRight.asInstanceOf[Integer] + case s: Short => s << valueRight.asInstanceOf[Integer] + case b: Byte => b << valueRight.asInstanceOf[Integer] + } + } else { + null + } + } else { + null + } + } + + override def dataType: DataType = { + left.dataType match { + case LongType => LongType + case IntegerType | ShortType | ByteType => IntegerType + case _ => NullType + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;") + } +} + +case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression { + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess + case (_, IntegerType) => left.dataType match { + case LongType | IntegerType | ShortType | ByteType => + return TypeCheckResult.TypeCheckSuccess + case _ => // failed + } + case _ => // failed + } + TypeCheckResult.TypeCheckFailure( + s"ShiftRight expects long, integer, short or byte value as first argument and an " + + s"integer value as second argument, not (${left.dataType}, ${right.dataType})") + } + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: Long => l >> valueRight.asInstanceOf[Integer] + case i: Integer => i >> valueRight.asInstanceOf[Integer] + case s: Short => s >> valueRight.asInstanceOf[Integer] + case b: Byte => b >> valueRight.asInstanceOf[Integer] + } + } else { + null + } + } else { + null + } + } + + override def dataType: DataType = { + left.dataType match { + case LongType => LongType + case IntegerType | ShortType | ByteType => IntegerType + case _ => NullType + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;") + } +} + /** * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index b3345d7069159..aa27fe3cd5564 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DataType, DoubleType, LongType} +import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType} class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -225,6 +225,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) } + test("shift left") { + checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42) + checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42) + checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42) + checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) + + checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) + } + + test("shift right") { + checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21) + checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21) + checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21) + checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) + + checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) + } + test("hex") { checkEvaluation(Hex(Literal(28)), "1C") checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e6f623bdf39eb..a5b68286853ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1298,6 +1298,44 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(columnName: String, numBits: Int): Column = + shiftLeft(Column(columnName), numBits) + + /** + * Shift the the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(columnName: String, numBits: Int): Column = + shiftRight(Column(columnName), numBits) + /** * Computes the signum of the given value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index c03cde38d75d0..4c5696deaff81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -259,6 +259,40 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("shift left") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), + shiftLeft('f, 1)), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)", + "shiftLeft(f, 1)"), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + } + + test("shift right") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1), + shiftRight('f, 1)), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)", + "shiftRight(f, 1)"), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + } + test("binary log") { val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") checkAnswer( From afa021e03f0a1a326be2ed742332845b77f94c55 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Jul 2015 10:06:38 -0700 Subject: [PATCH 105/274] [SPARK-8747] [SQL] fix EqualNullSafe for binary type also improve tests for binary comparison. Author: Wenchen Fan Closes #7143 from cloud-fan/binary and squashes the following commits: 28a5b76 [Wenchen Fan] improve test 04ef4b0 [Wenchen Fan] fix equalNullSafe --- .../sql/catalyst/expressions/predicates.scala | 3 +- .../catalyst/expressions/PredicateSuite.scala | 122 +++++++++++------- 2 files changed, 78 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 34df89a163895..d4569241e7364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -302,7 +302,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (l == null || r == null) { false } else { - l == r + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 72fec3b86e5e4..188ecef9e7679 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Date, Timestamp} - import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{IntegerType, BooleanType} +import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType} class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -66,12 +63,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { * Unknown Unknown */ // scalastyle:on - val notTrueTable = - (true, false) :: - (false, true) :: - (null, null) :: Nil test("3VL Not") { + val notTrueTable = + (true, false) :: + (false, true) :: + (null, null) :: Nil notTrueTable.foreach { case (v, answer) => checkEvaluation(Not(Literal.create(v, BooleanType)), answer) } @@ -126,8 +123,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val two = Literal(2) val three = Literal(3) val nl = Literal(null) - val s = Seq(one, two) - val nullS = Seq(one, two, null) checkEvaluation(InSet(one, hS), true) checkEvaluation(InSet(two, hS), true) checkEvaluation(InSet(two, nS), true) @@ -137,43 +132,78 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) } + private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) + private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_)) - test("BinaryComparison") { - val row = create_row(1, 2, 3, null, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - val c5 = 'a.int.at(4) - val c6 = 'a.int.at(5) + private val equalValues1 = smallValues + private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) - checkEvaluation(LessThan(c1, c4), null, row) - checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 < c2, true, row) - checkEvaluation(c1 <= c2, true, row) - checkEvaluation(c1 > c2, false, row) - checkEvaluation(c1 >= c2, false, row) - checkEvaluation(c1 === c2, false, row) - checkEvaluation(c1 !== c2, true, row) - checkEvaluation(c4 <=> c1, false, row) - checkEvaluation(c1 <=> c4, false, row) - checkEvaluation(c4 <=> c6, true, row) - checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) - checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) - - val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")) - val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-02")) - checkEvaluation(Literal(d1) < Literal(d2), true) - - val ts1 = new Timestamp(12) - val ts2 = new Timestamp(123) - checkEvaluation(Literal("ab") < Literal("abc"), true) - checkEvaluation(Literal(ts1) < Literal(ts2), true) + test("BinaryComparison: <") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) < largeValues(i), true) + checkEvaluation(equalValues1(i) < equalValues2(i), false) + checkEvaluation(largeValues(i) < smallValues(i), false) + } + } + + test("BinaryComparison: <=") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) <= largeValues(i), true) + checkEvaluation(equalValues1(i) <= equalValues2(i), true) + checkEvaluation(largeValues(i) <= smallValues(i), false) + } + } + + test("BinaryComparison: >") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) > largeValues(i), false) + checkEvaluation(equalValues1(i) > equalValues2(i), false) + checkEvaluation(largeValues(i) > smallValues(i), true) + } + } + + test("BinaryComparison: >=") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) >= largeValues(i), false) + checkEvaluation(equalValues1(i) >= equalValues2(i), true) + checkEvaluation(largeValues(i) >= smallValues(i), true) + } + } + + test("BinaryComparison: ===") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) === largeValues(i), false) + checkEvaluation(equalValues1(i) === equalValues2(i), true) + checkEvaluation(largeValues(i) === smallValues(i), false) + } + } + + test("BinaryComparison: <=>") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) <=> largeValues(i), false) + checkEvaluation(equalValues1(i) <=> equalValues2(i), true) + checkEvaluation(largeValues(i) <=> smallValues(i), false) + } + } + + test("BinaryComparison: null test") { + val normalInt = Literal(1) + val nullInt = Literal.create(null, IntegerType) + + def nullTest(op: (Expression, Expression) => Expression): Unit = { + checkEvaluation(op(normalInt, nullInt), null) + checkEvaluation(op(nullInt, normalInt), null) + checkEvaluation(op(nullInt, nullInt), null) + } + + nullTest(LessThan) + nullTest(LessThanOrEqual) + nullTest(GreaterThan) + nullTest(GreaterThanOrEqual) + nullTest(EqualTo) + + checkEvaluation(normalInt <=> nullInt, false) + checkEvaluation(nullInt <=> normalInt, false) + checkEvaluation(nullInt <=> nullInt, true) } } From 52302a803967114b29a8bf6b74459477364c5b88 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 2 Jul 2015 10:12:25 -0700 Subject: [PATCH 106/274] [SPARK-8407] [SQL] complex type constructors: struct and named_struct This is a follow up of [SPARK-8283](https://issues.apache.org/jira/browse/SPARK-8283) ([PR-6828](https://github.com/apache/spark/pull/6828)), to support both `struct` and `named_struct` in Spark SQL. After [#6725](https://github.com/apache/spark/pull/6828), the semantic of [`CreateStruct`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala#L56) methods have changed a little and do not limited to cols of `NamedExpressions`, it will name non-NamedExpression fields following the hive convention, col1, col2 ... This PR would both loosen [`struct`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L723) to take children of `Expression` type and add `named_struct` support. Author: Yijie Shen Closes #6874 from yijieshen/SPARK-8283 and squashes the following commits: 4cd3375ac [Yijie Shen] change struct documentation d599d0b [Yijie Shen] rebase code 9a7039e [Yijie Shen] fix reviews and regenerate golden answers b487354 [Yijie Shen] replace assert using checkAnswer f07e114 [Yijie Shen] tiny fix 9613be9 [Yijie Shen] review fix 7fef712 [Yijie Shen] Fix checkInputTypes' implementation using foldable and nullable 60812a7 [Yijie Shen] Fix type check 828d694 [Yijie Shen] remove unnecessary resolved assertion inside dataType method fd3cd8e [Yijie Shen] remove type check from eval 7a71255 [Yijie Shen] tiny fix ccbbd86 [Yijie Shen] Fix reviews 47da332 [Yijie Shen] remove nameStruct API from DataFrame 917e680 [Yijie Shen] Fix reviews 4bd75ad [Yijie Shen] loosen struct method in functions.scala to take Expression children 0acb7be [Yijie Shen] Add CreateNamedStruct in both DataFrame function API and FunctionRegistery --- python/pyspark/sql/functions.py | 1 - .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/complexTypeCreator.scala | 49 +++++++++++++++++++ .../ExpressionTypeCheckingSuite.scala | 11 +++++ .../expressions/ComplexTypeSuite.scala | 24 ++++++++- .../org/apache/spark/sql/functions.scala | 11 +++-- .../spark/sql/DataFrameFunctionsSuite.scala | 40 +++++++++++++-- ...ic udf-0-638f81ad9077c7d0c5c735c6e73742ad} | 0 .../sql/hive/execution/HiveQuerySuite.scala | 2 +- 9 files changed, 126 insertions(+), 13 deletions(-) rename sql/hive/src/test/resources/golden/{constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 => constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad} (100%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bccde6083ca3c..12263e6a75af8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -467,7 +467,6 @@ def struct(*cols): """Creates a new struct column. :param cols: list of column names (string) or list of :class:`Column` expressions - that are named or aliased. >>> df.select(struct('age', 'name').alias("struct")).collect() [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aa051b163363a..e7e4d1c4efe18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -96,6 +96,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), + expression[CreateNamedStruct]("named_struct"), expression[Sqrt]("sqrt"), // math functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e7dc4ec8b14..fa70409353e79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. @@ -54,6 +57,8 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) + override lazy val resolved: Boolean = childrenResolved + override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { @@ -74,3 +79,47 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "struct" } + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") + } else { + val invalidNames = + nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + if (invalidNames.size != 0) { + TypeCheckResult.TypeCheckFailure( + s"Odd position only allow foldable and not-null StringType expressions, got :" + + s" ${invalidNames.mkString(",")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + } + + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index bc1537b0715b5..8e0551b23eea6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -160,4 +160,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Explode('intField), "input to function explode should be array or map type") } + + test("check types for CreateNamedStruct") { + assertError( + CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") + assertError( + CreateNamedStruct(Seq(1, "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + assertError( + CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 3515d044b2f7e..a09014e1ffc15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -119,11 +121,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateStruct") { val row = create_row(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) } + test("CreateNamedStruct") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) + } + + test("CreateNamedStruct with literal field") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + } + + test("CreateNamedStruct from all literal fields") { + checkEvaluation( + CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + } + test("test dsl for complex type") { def quickResolve(u: UnresolvedExtractValue): Expression = { ExtractValue(u.child, u.extraction, _ == _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a5b68286853ed..4ee1fb8374b07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -739,17 +739,18 @@ object functions { def sqrt(colName: String): Column = sqrt(Column(colName)) /** - * Creates a new struct column. The input column must be a column in a [[DataFrame]], or - * a derived column expression that is named (i.e. aliased). + * Creates a new struct column. + * If the input column is a column in a [[DataFrame]], or a derived column expression + * that is named (i.e. aliased), its name would be remained as the StructField's name, + * otherwise, the newly generated StructField's name would be auto generated as col${index + 1}, + * i.e. col1, col2, col3, ... * * @group normal_funcs * @since 1.4.0 */ @scala.annotation.varargs def struct(cols: Column*): Column = { - require(cols.forall(_.expr.isInstanceOf[NamedExpression]), - s"struct input columns must all be named or aliased ($cols)") - CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + CreateStruct(cols.map(_.expr)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7ae89bcb1b9cf..0d43aca877f68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -79,10 +79,42 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row.getAs[Row](0) === Row(2, "str")) } - test("struct: must use named column expression") { - intercept[IllegalArgumentException] { - struct(col("a") * 2) - } + test("struct with column expression to be automatically named") { + val df = Seq((1, "str")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), col("b"))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Row(Row(2, "str"))) + } + + test("struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))) + } + + test("struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct(lit("v"), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", StringType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))) } test("constant functions") { diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad similarity index 100% rename from sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 rename to sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4cdba03b27022..991da2f829ae5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -132,7 +132,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { lower("AA"), "10", repeat(lower("AA"), 3), "11", lower(repeat("AA", 3)), "12", - printf("Bb%d", 12), "13", + printf("bb%d", 12), "13", repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""") createQueryTest("NaN to Decimal", From 0e553a3e9360a736920e2214d634373fef0dbcf7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Jul 2015 10:18:23 -0700 Subject: [PATCH 107/274] [SPARK-8708] [MLLIB] Paritition ALS ratings based on both users and products JIRA: https://issues.apache.org/jira/browse/SPARK-8708 Previously the partitions of ratings are only based on the given products. So if the `usersProducts` given for prediction contains only few products or even one product, the generated ratings will be pushed into few or single partition and can't use high parallelism. The following codes are the example reported in the JIRA. Because it asks the predictions for users on product 2. There is only one partition in the result. >>> r1 = (1, 1, 1.0) >>> r2 = (1, 2, 2.0) >>> r3 = (2, 1, 2.0) >>> r4 = (2, 2, 2.0) >>> r5 = (3, 1, 1.0) >>> ratings = sc.parallelize([r1, r2, r3, r4, r5], 5) >>> users = ratings.map(itemgetter(0)).distinct() >>> model = ALS.trainImplicit(ratings, 1, seed=10) >>> predictions_for_2 = model.predictAll(users.map(lambda u: (u, 2))) >>> predictions_for_2.glom().map(len).collect() [0, 0, 3, 0, 0] This PR uses user and product instead of only product to partition the ratings. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #7121 from viirya/mfm_fix_partition and squashes the following commits: 779946d [Liang-Chi Hsieh] Calculate approximate numbers of users and products in one pass. 4336dc2 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into mfm_fix_partition 83e56c1 [Liang-Chi Hsieh] Instead of additional join, use the numbers of users and products to decide how to perform join. b534dc8 [Liang-Chi Hsieh] Paritition ratings based on both users and products. --- .../MatrixFactorizationModel.scala | 55 +++++++++++++++++-- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 93aa41e49961e..43d219a49cf4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger} import scala.collection.mutable +import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.Path import org.json4s._ @@ -79,6 +80,30 @@ class MatrixFactorizationModel( blas.ddot(rank, userVector, 1, productVector, 1) } + /** + * Return approximate numbers of users and products in the given usersProducts tuples. + * This method is based on `countApproxDistinct` in class `RDD`. + * + * @param usersProducts RDD of (user, product) pairs. + * @return approximate numbers of users and products. + */ + private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = { + val zeroCounterUser = new HyperLogLogPlus(4, 0) + val zeroCounterProduct = new HyperLogLogPlus(4, 0) + val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))( + (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => { + hllTuple._1.offer(v._1) + hllTuple._2.offer(v._2) + hllTuple + }, + (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => { + h1._1.addAll(h2._1) + h1._2.addAll(h2._2) + h1 + }) + (aggregated._1.cardinality(), aggregated._2.cardinality()) + } + /** * Predict the rating of many users for many products. * The output RDD has an element per each element in the input RDD (including all duplicates) @@ -88,12 +113,30 @@ class MatrixFactorizationModel( * @return RDD of Ratings. */ def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { - val users = userFeatures.join(usersProducts).map { - case (user, (uFeatures, product)) => (product, (user, uFeatures)) - } - users.join(productFeatures).map { - case (product, ((user, uFeatures), pFeatures)) => - Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + // Previously the partitions of ratings are only based on the given products. + // So if the usersProducts given for prediction contains only few products or + // even one product, the generated ratings will be pushed into few or single partition + // and can't use high parallelism. + // Here we calculate approximate numbers of users and products. Then we decide the + // partitions should be based on users or products. + val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts) + + if (usersCount < productsCount) { + val users = userFeatures.join(usersProducts).map { + case (user, (uFeatures, product)) => (product, (user, uFeatures)) + } + users.join(productFeatures).map { + case (product, ((user, uFeatures), pFeatures)) => + Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + } + } else { + val products = productFeatures.join(usersProducts.map(_.swap)).map { + case (product, (pFeatures, user)) => (user, (product, pFeatures)) + } + products.join(userFeatures).map { + case (user, ((product, pFeatures), uFeatures)) => + Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + } } } From 2e2f32603c110b9c6ddfbb836f63882eacf0a8cc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 2 Jul 2015 10:57:02 -0700 Subject: [PATCH 108/274] [SPARK-8581] [SPARK-8584] Simplify checkpointing code + better error message This patch rewrites the old checkpointing code in a way that is easier to understand. It also adds a guard against an invalid specification of checkpoint directory to provide a clearer error message. Most of the changes here are relatively minor. Author: Andrew Or Closes #6968 from andrewor14/checkpoint-cleanup and squashes the following commits: 4ef8263 [Andrew Or] Use global synchronized instead 6f6fd84 [Andrew Or] Merge branch 'master' of github.com:apache/spark into checkpoint-cleanup b1437ad [Andrew Or] Warn instead of throw 5484293 [Andrew Or] Merge branch 'master' of github.com:apache/spark into checkpoint-cleanup 7fb4af5 [Andrew Or] Guard against bad settings of checkpoint directory 691da98 [Andrew Or] Simplify checkpoint code / code style / comments --- .../scala/org/apache/spark/SparkContext.scala | 10 +++ .../org/apache/spark/rdd/CheckpointRDD.scala | 17 +++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 14 ++-- .../apache/spark/rdd/RDDCheckpointData.scala | 71 +++++++++---------- .../org/apache/spark/CheckpointSuite.scala | 2 +- 5 files changed, 60 insertions(+), 54 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e5a86f44e410..8eed46759f340 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1906,6 +1906,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * be a HDFS path if running on a cluster. */ def setCheckpointDir(directory: String) { + + // If we are running on a cluster, log a warning if the directory is local. + // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from + // its own local file system, which is incorrect because the checkpoint files + // are actually on the executor machines. + if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { + logWarning("Checkpoint directory must be non-local " + + "if Spark is running on a cluster: " + directory) + } + checkpointDir = Option(directory).map { dir => val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 33e6998b2cb10..e17bd47905d7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.{SerializableConfiguration, Utils} -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). @@ -37,9 +37,11 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) + private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + + override def getCheckpointFile: Option[String] = Some(checkpointPath) override def getPartitions: Array[Partition] = { val cpath = new Path(checkpointPath) @@ -59,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - override def getPreferredLocations(split: Partition): Seq[String] = { val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) @@ -74,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) CheckpointRDD.readFromFile(file, broadcastedConf, context) } - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } + // CheckpointRDD should not be checkpointed again + override def checkpoint(): Unit = { } + override def doCheckpoint(): Unit = { } } private[spark] object CheckpointRDD extends Logging { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cac6e3b477e16..9f7ebae3e9af3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag]( @transient private var partitions_ : Array[Partition] = null /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) /** * Get the list of dependencies of this RDD, taking into account whether the @@ -1451,12 +1451,16 @@ abstract class RDD[T: ClassTag]( * executed on this RDD. It is strongly recommended that this RDD is persisted in * memory, otherwise saving it on a file will require recomputation. */ - def checkpoint() { + def checkpoint(): Unit = { if (context.checkpointDir.isEmpty) { throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() + // NOTE: we use a global lock here due to complexities downstream with ensuring + // children RDD partitions point to the correct parent partitions. In the future + // we should revisit this consideration. + RDDCheckpointData.synchronized { + checkpointData = Some(new RDDCheckpointData(this)) + } } } @@ -1497,7 +1501,7 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassTag] = { + protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index acbd31aacdf59..4f954363bed8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -22,16 +22,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.spark._ -import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} import org.apache.spark.util.SerializableConfiguration /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + * [ Initialized --> checkpointing in progress --> checkpointed ]. */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value + val Initialized, CheckpointingInProgress, Checkpointed = Value } /** @@ -46,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) import CheckpointState._ // The checkpoint state of the associated RDD. - var cpState = Initialized + private var cpState = Initialized // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None + private var cpFile: Option[String] = None // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None + // This is defined if and only if `cpState` is `Checkpointed`. + private var cpRDD: Option[CheckpointRDD[T]] = None - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } + // TODO: are we sure we need to use a global lock in the following methods? // Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } + def isCheckpointed: Boolean = RDDCheckpointData.synchronized { + cpState == Checkpointed } // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } + def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized { + cpFile } - // Do the checkpointing of the RDD. Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return + /** + * Materialize this RDD and write its content to a reliable DFS. + * This is called immediately after the first action invoked on this RDD has completed. + */ + def doCheckpoint(): Unit = { + + // Guard against multiple threads checkpointing the same RDD by + // atomically flipping the state of this RDDCheckpointData RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { + if (cpState == Initialized) { cpState = CheckpointingInProgress } else { return @@ -87,7 +86,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) + throw new SparkException(s"Failed to create checkpoint path $path") } // Save to file, and reload it as an RDD @@ -99,6 +98,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) } } + + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( @@ -113,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed } - logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) - } - - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } + logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}") } - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } + def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { + cpRDD.get.partitions } - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } + def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { + cpRDD } } private[spark] object RDDCheckpointData { + + /** Return the path of the directory to which this RDD's checkpoint data is written. */ def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { - sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } + sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } } + /** Clean up the files associated with the checkpoint data for this RDD. */ def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { rddCheckpointDataPath(sc, rddId).foreach { path => val fs = path.getFileSystem(sc.hadoopConfiguration) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d1761a48babbc..cc50e6d79a3e2 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -46,7 +46,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) flatMappedRDD.checkpoint() - assert(flatMappedRDD.dependencies.head.rdd == parCollection) + assert(flatMappedRDD.dependencies.head.rdd === parCollection) val result = flatMappedRDD.collect() assert(flatMappedRDD.dependencies.head.rdd != parCollection) assert(flatMappedRDD.collect() === result) From 34d448dbe1d7bd5bf9a8d6ef473878e570ca6161 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 2 Jul 2015 11:28:14 -0700 Subject: [PATCH 109/274] [SPARK-8479] [MLLIB] Add numNonzeros and numActives to linalg.Matrices Matrices allow zeros to be stored in values. Sometimes a method is handy to check if the numNonZeros are same as number of Active values. Author: MechCoder Closes #6904 from MechCoder/nnz_matrix and squashes the following commits: 252c6b7 [MechCoder] Add to MiMa excludes e2390f5 [MechCoder] Use count instead of foreach 2f62b2f [MechCoder] Add to MiMa excludes d6e96ef [MechCoder] [SPARK-8479] Add numNonzeros and numActives to linalg.Matrices --- .../apache/spark/mllib/linalg/Matrices.scala | 19 +++++++++++++++++++ .../spark/mllib/linalg/MatricesSuite.scala | 10 ++++++++++ project/MimaExcludes.scala | 6 ++++++ 3 files changed, 35 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 0a615494bb2d1..75e7004464af9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -114,6 +114,16 @@ sealed trait Matrix extends Serializable { * corresponding value in the matrix with type `Double`. */ private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + + /** + * Find the number of non-zero active values. + */ + def numNonzeros: Int + + /** + * Find the number of values stored explicitly. These values can be zero as well. + */ + def numActives: Int } @DeveloperApi @@ -324,6 +334,10 @@ class DenseMatrix( } } + override def numNonzeros: Int = values.count(_ != 0) + + override def numActives: Int = values.length + /** * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed * set to false. @@ -593,6 +607,11 @@ class SparseMatrix( def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } + + override def numNonzeros: Int = values.count(_ != 0) + + override def numActives: Int = values.length + } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 8dbb70f5d1c4c..a270ba2562db9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -455,4 +455,14 @@ class MatricesSuite extends SparkFunSuite { lines = mat.toString(5, 100).lines.toArray assert(lines.size == 5 && lines.forall(_.size <= 100)) } + + test("numNonzeros and numActives") { + val dm1 = Matrices.dense(3, 2, Array(0, 0, -1, 1, 0, 1)) + assert(dm1.numNonzeros === 3) + assert(dm1.numActives === 6) + + val sm1 = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)) + assert(sm1.numNonzeros === 1) + assert(sm1.numActives === 3) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6f86a505b3ae4..680b699e9e4a1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -75,6 +75,12 @@ object MimaExcludes { "org.apache.spark.sql.parquet.ParquetTypeInfo"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.ParquetTypeInfo$") + ) ++ Seq( + // SPARK-8479 Add numNonzeros and numActives to Matrix. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numActives") ) case v if v.startsWith("1.4") => Seq( From 82cf3315e690f4ac15b50edea6a3d673aa5be4c0 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 2 Jul 2015 13:49:45 -0700 Subject: [PATCH 110/274] [SPARK-8781] Fix variables in published pom.xml are not resolved The issue is summarized in the JIRA and is caused by this commit: 984ad60147c933f2d5a2040c87ae687c14eb1724. This patch reverts that commit and fixes the maven build in a different way. We limit the dependencies of `KinesisReceiverSuite` to avoid having to deal with the complexities in how maven deals with transitive test dependencies. Author: Andrew Or Closes #7193 from andrewor14/fix-kinesis-pom and squashes the following commits: ca3d5d4 [Andrew Or] Limit kinesis test dependencies f24e09c [Andrew Or] Revert "[BUILD] Fix Maven build for Kinesis" --- extras/kinesis-asl/pom.xml | 7 ------- .../kinesis/KinesisReceiverSuite.scala | 20 +++++++++++-------- pom.xml | 2 -- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c242e7a57b9ab..5289073eb457a 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -40,13 +40,6 @@ spark-streaming_${scala.binary.version} ${project.version}
- - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.spark spark-streaming_${scala.binary.version} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 2103dca6b766f..6c262624833cd 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -26,18 +26,23 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionIn import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, Matchers} +// scalastyle:off +// To avoid introducing a dependency on Spark core tests, simply use scalatest's FunSuite +// here instead of our own SparkFunSuite. Introducing the dependency has caused problems +// in the past (SPARK-8781) that are complicated by bugs in the maven shade plugin (MSHADE-148). +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext} import org.apache.spark.util.{Clock, ManualClock, Utils} /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ -class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with MockitoSugar { +class KinesisReceiverSuite extends FunSuite with Matchers with BeforeAndAfter + with MockitoSugar { +// scalastyle:on val app = "TestKinesisReceiver" val stream = "mySparkStream" @@ -57,7 +62,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var checkpointStateMock: KinesisCheckpointState = _ var currentClockMock: Clock = _ - override def beforeFunction(): Unit = { + before { receiverMock = mock[KinesisReceiver] checkpointerMock = mock[IRecordProcessorCheckpointer] checkpointClockMock = mock[ManualClock] @@ -65,8 +70,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft currentClockMock = mock[Clock] } - override def afterFunction(): Unit = { - super.afterFunction() + after { // Since this suite was originally written using EasyMock, add this to preserve the old // mocking semantics (see SPARK-5735 for more details) verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, @@ -74,7 +78,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft } test("KinesisUtils API") { - val ssc = new StreamingContext(master, framework, batchDuration) + val ssc = new StreamingContext("local[2]", getClass.getSimpleName, Seconds(1)) // Tests the API, does not actually test data receiving val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", "https://kinesis.us-west-2.amazonaws.com", Seconds(2), diff --git a/pom.xml b/pom.xml index 211da9ee74a3f..ffa96128a3d61 100644 --- a/pom.xml +++ b/pom.xml @@ -1440,8 +1440,6 @@ 2.3 false - - false From fcbcba66c92871fe3936e5ca605017e9c2a2eb95 Mon Sep 17 00:00:00 2001 From: Deron Eriksson Date: Thu, 2 Jul 2015 13:55:53 -0700 Subject: [PATCH 111/274] [SPARK-1564] [DOCS] Added Javascript to Javadocs to create badges for tags like :: Experimental :: Modified copy_api_dirs.rb and created api-javadocs.js and api-javadocs.css files in order to add badges to javadoc files for :: Experimental ::, :: DeveloperApi ::, and :: AlphaComponent :: tags Author: Deron Eriksson Closes #7169 from deroneriksson/SPARK-1564_JavaDocs_badges and squashes the following commits: a8353db [Deron Eriksson] added license headers to api-docs.css and api-javadocs.css 07feb07 [Deron Eriksson] added linebreaks to make jquery more readable when adding html badge tags 65b4930 [Deron Eriksson] Modified copy_api_dirs.rb and created api-javadocs.js and api-javadocs.css files in order to add badges to javadoc files for :: Experimental ::, :: DeveloperApi ::, and :: AlphaComponent :: tags --- docs/_plugins/copy_api_dirs.rb | 45 +++++++++++++++++++++++++ docs/css/api-docs.css | 17 ++++++++++ docs/css/api-javadocs.css | 52 +++++++++++++++++++++++++++++ docs/js/api-javadocs.js | 60 ++++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+) create mode 100644 docs/css/api-javadocs.css create mode 100644 docs/js/api-javadocs.js diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 6073b3626c45b..15ceda11a8a80 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,6 +63,51 @@ puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) + + # Begin updating JavaDoc files for badge post-processing + puts "Updating JavaDoc files for badge post-processing" + js_script_start = '' + + javadoc_files = Dir["./" + dest + "/**/*.html"] + javadoc_files.each do |javadoc_file| + # Determine file depths to reference js files + slash_count = javadoc_file.count "/" + i = 3 + path_to_js_file = "" + while (i < slash_count) do + path_to_js_file = path_to_js_file + "../" + i += 1 + end + + # Create script elements to reference js files + javadoc_jquery_script = js_script_start + path_to_js_file + "lib/jquery" + js_script_end; + javadoc_api_docs_script = js_script_start + path_to_js_file + "lib/api-javadocs" + js_script_end; + javadoc_script_elements = javadoc_jquery_script + javadoc_api_docs_script + + # Add script elements to JavaDoc files + javadoc_file_content = File.open(javadoc_file, "r") { |f| f.read } + javadoc_file_content = javadoc_file_content.sub("", javadoc_script_elements + "") + File.open(javadoc_file, "w") { |f| f.puts(javadoc_file_content) } + + end + # End updating JavaDoc files for badge post-processing + + puts "Copying jquery.js from Scala API to Java API for page post-processing of badges" + jquery_src_file = "./api/scala/lib/jquery.js" + jquery_dest_file = "./api/java/lib/jquery.js" + mkdir_p("./api/java/lib") + cp(jquery_src_file, jquery_dest_file) + + puts "Copying api_javadocs.js to Java API for page post-processing of badges" + api_javadocs_src_file = "./js/api-javadocs.js" + api_javadocs_dest_file = "./api/java/lib/api-javadocs.js" + cp(api_javadocs_src_file, api_javadocs_dest_file) + + puts "Appending content of api-javadocs.css to JavaDoc stylesheet.css for badge styles" + css = File.readlines("./css/api-javadocs.css") + css_file = dest + "/stylesheet.css" + File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } end # Build Sphinx docs for Python diff --git a/docs/css/api-docs.css b/docs/css/api-docs.css index b2d1d7f869790..7cf222aad24f6 100644 --- a/docs/css/api-docs.css +++ b/docs/css/api-docs.css @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /* Dynamically injected style for the API docs */ .developer { diff --git a/docs/css/api-javadocs.css b/docs/css/api-javadocs.css new file mode 100644 index 0000000000000..832e92609e011 --- /dev/null +++ b/docs/css/api-javadocs.css @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Dynamically injected style for the API docs */ + +.badge { + font-family: Arial, san-serif; + float: right; + margin: 4px; + /* The following declarations are taken from the ScalaDoc template.css */ + display: inline-block; + padding: 2px 4px; + font-size: 11.844px; + font-weight: bold; + line-height: 14px; + color: #ffffff; + text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25); + white-space: nowrap; + vertical-align: baseline; + background-color: #999999; + padding-right: 9px; + padding-left: 9px; + -webkit-border-radius: 9px; + -moz-border-radius: 9px; + border-radius: 9px; +} + +.developer { + background-color: #44751E; +} + +.experimental { + background-color: #257080; +} + +.alphaComponent { + background-color: #bb0000; +} diff --git a/docs/js/api-javadocs.js b/docs/js/api-javadocs.js new file mode 100644 index 0000000000000..ead13d6e5fa7c --- /dev/null +++ b/docs/js/api-javadocs.js @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Dynamically injected post-processing code for the API docs */ + +$(document).ready(function() { + addBadges(":: AlphaComponent ::", 'Alpha Component'); + addBadges(":: DeveloperApi ::", 'Developer API'); + addBadges(":: Experimental ::", 'Experimental'); +}); + +function addBadges(tag, html) { + var tags = $(".block:contains(" + tag + ")") + + // Remove identifier tags + tags.each(function(index) { + var oldHTML = $(this).html(); + var newHTML = oldHTML.replace(tag, ""); + $(this).html(newHTML); + }); + + // Add html badge tags + tags.each(function(index) { + if ($(this).parent().is('td.colLast')) { + $(this).parent().prepend(html); + } else if ($(this).parent('li.blockList') + .parent('ul.blockList') + .parent('div.description') + .parent().is('div.contentContainer')) { + var contentContainer = $(this).parent('li.blockList') + .parent('ul.blockList') + .parent('div.description') + .parent('div.contentContainer') + var header = contentContainer.prev('div.header'); + if (header.length > 0) { + header.prepend(html); + } else { + contentContainer.prepend(html); + } + } else if ($(this).parent().is('li.blockList')) { + $(this).parent().prepend(html); + } else { + $(this).prepend(html); + } + }); +} From cd2035507891a7f426f6f45902d3b5f4fdbe88cf Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 2 Jul 2015 13:59:56 -0700 Subject: [PATCH 112/274] [SPARK-7835] Refactor HeartbeatReceiverSuite for coverage + cleanup The existing test suite has a lot of duplicate code and doesn't even cover the most fundamental feature of the HeartbeatReceiver, which is expiring hosts that have not responded in a while. This introduces manual clocks in `HeartbeatReceiver` and makes it respond to heartbeats only for registered executors. A few internal messages are moved to `receiveAndReply` to increase determinism of the tests so we don't have to rely on flaky constructs like `eventually`. Author: Andrew Or Closes #7173 from andrewor14/heartbeat-receiver-tests and squashes the following commits: 4a903d6 [Andrew Or] Increase HeartReceiverSuite coverage and clean up --- .../org/apache/spark/HeartbeatReceiver.scala | 89 +++++++--- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../apache/spark/HeartbeatReceiverSuite.scala | 161 +++++++++++++----- 3 files changed, 191 insertions(+), 61 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 6909015ff66e6..221b1dab43278 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet private[spark] case object ExpireDeadHosts +private case class ExecutorRegistered(executorId: String) + +private case class ExecutorRemoved(executorId: String) + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext) - extends ThreadSafeRpcEndpoint with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) + extends ThreadSafeRpcEndpoint with SparkListener with Logging { + + def this(sc: SparkContext) { + this(sc, new SystemClock) + } + + sc.addSparkListener(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv @@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def onStart(): Unit = { timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ExpireDeadHosts)) + Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) } }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - override def receive: PartialFunction[Any, Unit] = { - case ExpireDeadHosts => - expireDeadHosts() + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + + // Messages sent and received locally + case ExecutorRegistered(executorId) => + executorLastSeen(executorId) = clock.getTimeMillis() + context.reply(true) + case ExecutorRemoved(executorId) => + executorLastSeen.remove(executorId) + context.reply(true) case TaskSchedulerIsSet => scheduler = sc.taskScheduler - } + context.reply(true) + case ExpireDeadHosts => + expireDeadHosts() + context.reply(true) - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // Messages received from executors case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - executorLastSeen(executorId) = System.currentTimeMillis() - eventLoopThread.submit(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - context.reply(response) - } - }) + if (executorLastSeen.contains(executorId)) { + executorLastSeen(executorId) = clock.getTimeMillis() + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) + } else { + // This may happen if we get an executor's in-flight heartbeat immediately + // after we just removed it. It's not really an error condition so we should + // not log warning here. Otherwise there may be a lot of noise especially if + // we explicitly remove executors (SPARK-4134). + logDebug(s"Received heartbeat from unknown executor $executorId") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } + /** + * If the heartbeat receiver is not stopped, notify it of executor registrations. + */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + } + + /** + * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't + * log superfluous errors. + * + * Note that we must do this after the executor is actually removed to guard against the + * following race condition: if we remove an executor's metadata from our data structure + * prematurely, we may get an in-flight heartbeat from the executor before the executor is + * actually removed, in which case we will still mark the executor as a dead host later + * and expire it with loud error messages. + */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + } + private def expireDeadHosts(): Unit = { logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") - val now = System.currentTimeMillis() + val now = clock.getTimeMillis() for ((executorId, lastSeenMs) <- executorLastSeen) { if (now - lastSeenMs > executorTimeoutMs) { logWarning(s"Removing executor $executorId with no recent heartbeats: " + diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8eed46759f340..d2547eeff2b4e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -498,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) - _heartbeatReceiver.send(TaskSchedulerIsSet) + _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 911b3bddd1836..b31b09196608f 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -17,64 +17,145 @@ package org.apache.spark -import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.storage.BlockManagerId +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.RpcUtils -import org.scalatest.concurrent.Eventually._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ManualClock -class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext { +class HeartbeatReceiverSuite + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester + with LocalSparkContext { - test("HeartbeatReceiver") { + private val executorId1 = "executor-1" + private val executorId2 = "executor-2" + + // Shared state that must be reset before and after each test + private var scheduler: TaskScheduler = null + private var heartbeatReceiver: HeartbeatReceiver = null + private var heartbeatReceiverRef: RpcEndpointRef = null + private var heartbeatReceiverClock: ManualClock = null + + override def beforeEach(): Unit = { sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + scheduler = mock(classOf[TaskScheduler]) when(sc.taskScheduler).thenReturn(scheduler) + heartbeatReceiverClock = new ManualClock + heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) + heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + override def afterEach(): Unit = { + resetSparkContext() + scheduler = null + heartbeatReceiver = null + heartbeatReceiverRef = null + heartbeatReceiverClock = null + } - val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + test("task scheduler is set correctly") { + assert(heartbeatReceiver.scheduler === null) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + assert(heartbeatReceiver.scheduler !== null) + } - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) - assert(false === response.reregisterBlockManager) + test("normal heartbeat") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 2) + assert(trackedExecutors.contains(executorId1)) + assert(trackedExecutors.contains(executorId2)) } - test("HeartbeatReceiver re-register") { - sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) - when(sc.taskScheduler).thenReturn(scheduler) + test("reregister if scheduler is not ready yet") { + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + // Task scheduler not set in HeartbeatReceiver + triggerHeartbeat(executorId1, executorShouldReregister = true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + test("reregister if heartbeat from unregistered executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + // Received heartbeat from unknown receiver, so we ask it to re-register + triggerHeartbeat(executorId1, executorShouldReregister = true) + assert(executorLastSeen(heartbeatReceiver).isEmpty) + } + + test("reregister if heartbeat from removed executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + // Remove the second executor but not the first + heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy")) + // Now trigger the heartbeats + // A heartbeat from the second executor should require reregistering + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = true) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + test("expire dead hosts") { + val executorTimeout = executorTimeoutMs(heartbeatReceiver) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + // Advance the clock and only trigger a heartbeat for the first executor + heartbeatReceiverClock.advance(executorTimeout / 2) + triggerHeartbeat(executorId1, executorShouldReregister = false) + heartbeatReceiverClock.advance(executorTimeout) + heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + // Only the second executor should be expired as a dead host + verify(scheduler).executorLost(Matchers.eq(executorId2), any()) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + + /** Manually send a heartbeat and return the response. */ + private def triggerHeartbeat( + executorId: String, + executorShouldReregister: Boolean): Unit = { val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + val blockManagerId = BlockManagerId(executorId, "localhost", 12345) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + Heartbeat(executorId, Array(1L -> metrics), blockManagerId)) + if (executorShouldReregister) { + assert(response.reregisterBlockManager) + } else { + assert(!response.reregisterBlockManager) + // Additionally verify that the scheduler callback is called with the correct parameters + verify(scheduler).executorHeartbeatReceived( + Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + } + } - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) - assert(true === response.reregisterBlockManager) + // Helper methods to access private fields in HeartbeatReceiver + private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) + private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) + private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = { + receiver invokePrivate _executorLastSeen() + } + private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = { + receiver invokePrivate _executorTimeoutMs() } + } From 52508beb650a863ed5c89384414b3b7675cac11e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 2 Jul 2015 14:16:14 -0700 Subject: [PATCH 113/274] [SPARK-8772][SQL] Implement implicit type cast for expressions that define input types. Author: Reynold Xin Closes #7175 from rxin/implicitCast and squashes the following commits: 88080a2 [Reynold Xin] Clearer definition of implicit type cast. f0ff97f [Reynold Xin] Added missing file. c65e532 [Reynold Xin] [SPARK-8772][SQL] Implement implicit type cast for expressions that defines input types. --- .../catalyst/analysis/HiveTypeCoercion.scala | 41 ++++++- .../expressions/ExpectsInputTypes.scala | 24 +--- .../spark/sql/catalyst/expressions/math.scala | 7 +- .../spark/sql/catalyst/expressions/misc.scala | 12 +- .../sql/catalyst/expressions/predicates.scala | 14 +-- .../expressions/stringOperations.scala | 10 +- .../spark/sql/types/AbstractDataType.scala | 114 ++++++++++++++++++ .../apache/spark/sql/types/ArrayType.scala | 6 +- .../apache/spark/sql/types/BinaryType.scala | 2 - .../apache/spark/sql/types/BooleanType.scala | 2 - .../org/apache/spark/sql/types/ByteType.scala | 2 - .../org/apache/spark/sql/types/DataType.scala | 86 +------------ .../org/apache/spark/sql/types/DateType.scala | 5 +- .../apache/spark/sql/types/DecimalType.scala | 7 +- .../apache/spark/sql/types/DoubleType.scala | 2 - .../apache/spark/sql/types/FloatType.scala | 2 - .../apache/spark/sql/types/IntegerType.scala | 2 - .../org/apache/spark/sql/types/LongType.scala | 2 - .../org/apache/spark/sql/types/MapType.scala | 7 +- .../org/apache/spark/sql/types/NullType.scala | 2 - .../apache/spark/sql/types/ShortType.scala | 2 - .../apache/spark/sql/types/StringType.scala | 2 - .../apache/spark/sql/types/StructType.scala | 2 - .../spark/sql/types/TimestampType.scala | 2 - .../analysis/HiveTypeCoercionSuite.scala | 25 ++++ 25 files changed, 213 insertions(+), 169 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8420c54f7c335..0bc893224026e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -704,19 +704,48 @@ object HiveTypeCoercion { /** * Casts types according to the expected input types for Expressions that have the trait - * [[AutoCastInputTypes]]. + * [[ExpectsInputTypes]]. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes => - val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map { - case (child, actual, expected) => - if (actual == expected) child else Cast(child, expected) + case e: ExpectsInputTypes => + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + implicitCast(in, expected) } - e.withNewChildren(newC) + e.withNewChildren(children) + } + + /** + * If needed, cast the expression into the expected type. + * If the implicit cast is not allowed, return the expression itself. + */ + def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = { + val inType = e.dataType + (inType, expectedType) match { + // Cast null type (usually from null literals) into target types + case (NullType, target: DataType) => Cast(e, target.defaultConcreteType) + + // Implicit cast among numeric types + case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) + + // Implicit cast between date time types + case (DateType, TimestampType) => Cast(e, TimestampType) + case (TimestampType, DateType) => Cast(e, DateType) + + // Implicit cast from/to string + case (StringType, NumericType) => Cast(e, DoubleType) + case (StringType, target: NumericType) => Cast(e, target) + case (StringType, DateType) => Cast(e, DateType) + case (StringType, TimestampType) => Cast(e, TimestampType) + case (StringType, BinaryType) => Cast(e, BinaryType) + case (any, StringType) if any != StringType => Cast(e, StringType) + + // Else, just return the same input expression + case _ => e + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 450fc4165f93b..916e30154d4f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.AbstractDataType /** @@ -32,28 +32,12 @@ trait ExpectsInputTypes { self: Expression => * * The possible values at each position are: * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + * 2. a non-leaf abstract data type, e.g. NumericType, IntegralType, FractionalType. */ - def inputTypes: Seq[Any] + def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess - } -} - -/** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. - */ -trait AutoCastInputTypes { self: Expression => - - def inputTypes: Seq[DataType] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. + // TODO: implement proper type checking. TypeCheckResult.TypeCheckSuccess } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7504c6a066657..035980da568d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -56,8 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with AutoCastInputTypes { - self: Product => + extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -96,7 +95,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -208,7 +207,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with AutoCastInputTypes { + extends UnaryExpression with Serializable with ExpectsInputTypes { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 407023e472081..e008af3966941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,8 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ -case class Md5(child: Expression) - extends UnaryExpression with AutoCastInputTypes { +case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -62,12 +61,10 @@ case class Md5(child: Expression) * the hash length is not one of the permitted values, the return value is NULL. */ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with AutoCastInputTypes { + extends BinaryExpression with Serializable with ExpectsInputTypes { override def dataType: DataType = StringType - override def toString: String = s"SHA2($left, $right)" - override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) override def eval(input: InternalRow): Any = { @@ -147,7 +144,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -174,8 +171,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ -case class Crc32(child: Expression) - extends UnaryExpression with AutoCastInputTypes { +case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d4569241e7364..0b479f466c63c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,7 +69,7 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } -case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { +case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { override def toString: String = s"NOT $child" override def inputTypes: Seq[DataType] = Seq(BooleanType) @@ -120,11 +120,11 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with ExpectsInputTypes { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def toString: String = s"($left && $right)" - override def symbol: String = "&&" + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def eval(input: InternalRow): Any = { val l = left.eval(input) @@ -169,11 +169,11 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with ExpectsInputTypes { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def toString: String = s"($left || $right)" - override def symbol: String = "||" + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def eval(input: InternalRow): Any = { val l = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b020f2bbc5818..57918b32f8a47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends AutoCastInputTypes { +trait StringRegexExpression extends ExpectsInputTypes { self: BinaryExpression => def escape(v: String): String @@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait CaseConversionExpression extends AutoCastInputTypes { +trait CaseConversionExpression extends ExpectsInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -154,7 +154,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends AutoCastInputTypes { +trait StringComparison extends ExpectsInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -215,7 +215,7 @@ case class EndsWith(left: Expression, right: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with AutoCastInputTypes { + extends Expression with ExpectsInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -283,7 +283,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string expression. */ -case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { +case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala new file mode 100644 index 0000000000000..43e2f8a46e62e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} + +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils + +/** + * A non-concrete data type, reserved for internal uses. + */ +private[sql] abstract class AbstractDataType { + private[sql] def defaultConcreteType: DataType +} + + +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] + + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) + } +} + + +/** + * :: DeveloperApi :: + * Numeric data types. + */ +abstract class NumericType extends AtomicType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the objects constructor. This means there is no + // longer an no argument constructor and thus the JVM cannot serialize the object anymore. + private[sql] val numeric: Numeric[InternalType] +} + + +private[sql] object NumericType extends AbstractDataType { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ NumericType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] + + private[sql] override def defaultConcreteType: DataType = IntegerType +} + + +private[sql] object IntegralType extends AbstractDataType { + /** + * Enables matching against IntegralType for expressions: + * {{{ + * case Cast(child @ IntegralType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] + + private[sql] override def defaultConcreteType: DataType = IntegerType +} + + +private[sql] abstract class IntegralType extends NumericType { + private[sql] val integral: Integral[InternalType] +} + + +private[sql] object FractionalType extends AbstractDataType { + /** + * Enables matching against FractionalType for expressions: + * {{{ + * case Cast(child @ FractionalType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] + + private[sql] override def defaultConcreteType: DataType = DoubleType +} + + +private[sql] abstract class FractionalType extends NumericType { + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index b116163faccad..81553e7fc91a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -22,9 +22,11 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi -object ArrayType { +object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) + + override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) } @@ -41,8 +43,6 @@ object ArrayType { * * @param elementType The data type of values. * @param containsNull Indicates if values have `null` values - * - * @group dataType */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index 9b58601e5e6ec..f2c6f34ea51c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -29,8 +29,6 @@ import org.apache.spark.sql.catalyst.util.TypeUtils * :: DeveloperApi :: * The data type representing `Array[Byte]` values. * Please use the singleton [[DataTypes.BinaryType]]. - * - * @group dataType */ @DeveloperApi class BinaryType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index a7f228cefa57a..2d8ee3d9bc286 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. - * - *@group dataType */ @DeveloperApi class BooleanType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 4d8685796ec76..2ca427975a1cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. - * - * @group dataType */ @DeveloperApi class ByteType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 74677ddfcad65..c333fa70d1ef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers import org.json4s._ @@ -27,19 +25,15 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.util.Utils /** * :: DeveloperApi :: * The base type of all Spark SQL data types. - * - * @group dataType */ @DeveloperApi -abstract class DataType { +abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: * {{{ @@ -80,84 +74,8 @@ abstract class DataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def asNullable: DataType -} - - -/** - * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. - */ -protected[sql] abstract class AtomicType extends DataType { - private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] - private[sql] val ordering: Ordering[InternalType] - - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) - } -} - - -/** - * :: DeveloperApi :: - * Numeric data types. - * - * @group dataType - */ -abstract class NumericType extends AtomicType { - // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for - // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets - // desugared by the compiler into an argument to the objects constructor. This means there is no - // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - private[sql] val numeric: Numeric[InternalType] -} - - -private[sql] object NumericType { - /** - * Enables matching against NumericType for expressions: - * {{{ - * case Cast(child @ NumericType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] -} - - -private[sql] object IntegralType { - /** - * Enables matching against IntegralType for expressions: - * {{{ - * case Cast(child @ IntegralType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] -} - - -private[sql] abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[InternalType] -} - - -private[sql] object FractionalType { - /** - * Enables matching against FractionalType for expressions: - * {{{ - * case Cast(child @ FractionalType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] -} - -private[sql] abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[InternalType] - private[sql] val asIntegral: Integral[InternalType] + override def defaultConcreteType: DataType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 03f0644bc784c..1d73e40ffcd36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -26,10 +26,11 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: - * The data type representing `java.sql.Date` values. + * A date type, supporting "0001-01-01" through "9999-12-31". + * * Please use the singleton [[DataTypes.DateType]]. * - * @group dataType + * Internally, this is represented as the number of days from epoch (1970-01-01 00:00:00 UTC). */ @DeveloperApi class DateType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 18cdfa7238f39..06373a095b1b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -39,8 +39,6 @@ case class PrecisionInfo(precision: Int, scale: Int) { * A Decimal that might have fixed precision and scale, or unlimited values for these. * * Please use [[DataTypes.createDecimalType()]] to create a specific instance. - * - * @group dataType */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { @@ -84,7 +82,10 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ -object DecimalType { +object DecimalType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = Unlimited + val Unlimited: DecimalType = DecimalType(None) private[sql] object Fixed { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 66766623213c9..986c2ab055386 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. - * - * @group dataType */ @DeveloperApi class DoubleType private() extends FractionalType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 1d5a2f4f6f86c..9bd48ece83a1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. - * - * @group dataType */ @DeveloperApi class FloatType private() extends FractionalType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index 74e464c082873..a2c6e19b05b3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. - * - * @group dataType */ @DeveloperApi class IntegerType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 390675782e5fd..2b3adf6ade83b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. - * - * @group dataType */ @DeveloperApi class LongType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index cfdf493074415..69c2119e23436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -30,8 +30,6 @@ import org.json4s.JsonDSL._ * @param keyType The data type of map keys. * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. - * - * @group dataType */ case class MapType( keyType: DataType, @@ -69,7 +67,10 @@ case class MapType( } -object MapType { +object MapType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType) + /** * Construct a [[MapType]] object with the given key type and value type. * The `valueContainsNull` is true. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index b64b07431fa96..aa84115c2e42c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -23,8 +23,6 @@ import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. - * - * @group dataType */ @DeveloperApi class NullType private() extends DataType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 73e9ec780b0af..a13119e659064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. - * - * @group dataType */ @DeveloperApi class ShortType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 1e9476ad06656..a7627a2de1611 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -27,8 +27,6 @@ import org.apache.spark.unsafe.types.UTF8String /** * :: DeveloperApi :: * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. - * - * @group dataType */ @DeveloperApi class StringType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 2db0a359e9db5..6fedeabf23203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -87,8 +87,6 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} * val row = Row(Row(1, 2, true)) * // row: Row = [[1,2,true]] * }}} - * - * @group dataType */ @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index a558641fcfed7..de4b511edccd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock * :: DeveloperApi :: * The data type representing `java.sql.Timestamp` values. * Please use the singleton [[DataTypes.TimestampType]]. - * - * @group dataType */ @DeveloperApi class TimestampType private() extends AtomicType { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index eae3666595a38..498fd86a06fd9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -26,6 +26,31 @@ import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { + test("implicit type cast") { + def shouldCast(from: DataType, to: AbstractDataType): Unit = { + val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + assert(got.dataType === to.defaultConcreteType) + } + + // TODO: write the entire implicit cast table out for test cases. + shouldCast(ByteType, IntegerType) + shouldCast(IntegerType, IntegerType) + shouldCast(IntegerType, LongType) + shouldCast(IntegerType, DecimalType.Unlimited) + shouldCast(LongType, IntegerType) + shouldCast(LongType, DecimalType.Unlimited) + + shouldCast(DateType, TimestampType) + shouldCast(TimestampType, DateType) + + shouldCast(StringType, IntegerType) + shouldCast(StringType, DateType) + shouldCast(StringType, TimestampType) + shouldCast(IntegerType, StringType) + shouldCast(DateType, StringType) + shouldCast(TimestampType, StringType) + } + test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) From 7d9cc9673e47227f58411ca1f4e647cd8233a219 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 2 Jul 2015 15:00:13 -0700 Subject: [PATCH 114/274] [SPARK-3382] [MLLIB] GradientDescent convergence tolerance GrandientDescent can receive convergence tolerance value. Default value is 0.0. When loss value becomes less than the tolerance which is set by user, iteration is terminated. Author: lewuathe Closes #3636 from Lewuathe/gd-convergence-tolerance and squashes the following commits: 0b8a9a8 [lewuathe] Update doc ce91b15 [lewuathe] Merge branch 'master' into gd-convergence-tolerance 4f22c2b [lewuathe] Modify based on SPARK-1503 5e47b82 [lewuathe] Merge branch 'master' into gd-convergence-tolerance abadb7e [lewuathe] Fix LassoSuite 8fadebd [lewuathe] Fix failed unit tests ee5de46 [lewuathe] Merge branch 'master' into gd-convergence-tolerance 8313ba2 [lewuathe] Fix styles 0ead94c [lewuathe] Merge branch 'master' into gd-convergence-tolerance a94cfd5 [lewuathe] Modify some styles 3aef0a2 [lewuathe] Modify converged logic to do relative comparison f7b19d5 [lewuathe] [SPARK-3382] Clarify comparison logic e6c9cd2 [lewuathe] [SPARK-3382] Compare with the diff of solution vector 4b125d2 [lewuathe] [SPARK3382] Fix scala style e7c10dd [lewuathe] [SPARK-3382] format improvements f867eea [lewuathe] [SPARK-3382] Modify warning message statements b9d5e61 [lewuathe] [SPARK-3382] should compare diff inside loss history and convergence tolerance 5433f71 [lewuathe] [SPARK-3382] GradientDescent convergence tolerance --- .../mllib/optimization/GradientDescent.scala | 105 +++++++++++++++--- .../StreamingLinearRegressionWithSGD.scala | 6 + .../LogisticRegressionSuite.scala | 1 + .../optimization/GradientDescentSuite.scala | 45 +++++++- .../spark/mllib/optimization/LBFGSSuite.scala | 6 +- .../spark/mllib/regression/LassoSuite.scala | 2 +- .../StreamingLinearRegressionSuite.scala | 1 + 7 files changed, 144 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 06e45e10c5bf4..ab7611fd077ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,13 +19,14 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV} +import breeze.linalg.{DenseVector => BDV, norm} import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} + /** * Class used to solve an optimization problem using Gradient Descent. * @param gradient Gradient function to be used. @@ -38,6 +39,7 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va private var numIterations: Int = 100 private var regParam: Double = 0.0 private var miniBatchFraction: Double = 1.0 + private var convergenceTol: Double = 0.001 /** * Set the initial step size of SGD for the first step. Default 1.0. @@ -75,6 +77,23 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va this } + /** + * Set the convergence tolerance. Default 0.001 + * convergenceTol is a condition which decides iteration termination. + * The end of iteration is decided based on below logic. + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * Must be between 0.0 and 1.0 inclusively. + */ + def setConvergenceTol(tolerance: Double): this.type = { + require(0.0 <= tolerance && tolerance <= 1.0) + this.convergenceTol = tolerance + this + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for SGD. @@ -112,7 +131,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va numIterations, regParam, miniBatchFraction, - initialWeights) + initialWeights, + convergenceTol) weights } @@ -131,17 +151,20 @@ object GradientDescent extends Logging { * Sampling, and averaging the subgradients over this subset is performed using one standard * spark map-reduce in each iteration. * - * @param data - Input data for SGD. RDD of the set of data examples, each of - * the form (label, [feature values]). - * @param gradient - Gradient object (used to compute the gradient of the loss function of - * one single data example) - * @param updater - Updater function to actually perform a gradient step in a given direction. - * @param stepSize - initial step size for the first step - * @param numIterations - number of iterations that SGD should be run. - * @param regParam - regularization parameter - * @param miniBatchFraction - fraction of the input data set that should be used for - * one iteration of SGD. Default value 1.0. - * + * @param data Input data for SGD. RDD of the set of data examples, each of + * the form (label, [feature values]). + * @param gradient Gradient object (used to compute the gradient of the loss function of + * one single data example) + * @param updater Updater function to actually perform a gradient step in a given direction. + * @param stepSize initial step size for the first step + * @param numIterations number of iterations that SGD should be run. + * @param regParam regularization parameter + * @param miniBatchFraction fraction of the input data set that should be used for + * one iteration of SGD. Default value 1.0. + * @param convergenceTol Minibatch iteration will end before numIterations if the relative + * difference between the current weight and the previous weight is less + * than this value. In measuring convergence, L2 norm is calculated. + * Default value 0.001. Must be between 0.0 and 1.0 inclusively. * @return A tuple containing two elements. The first element is a column matrix containing * weights for every feature, and the second element is an array containing the * stochastic loss computed for every iteration. @@ -154,9 +177,20 @@ object GradientDescent extends Logging { numIterations: Int, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): (Vector, Array[Double]) = { + initialWeights: Vector, + convergenceTol: Double): (Vector, Array[Double]) = { + + // convergenceTol should be set with non minibatch settings + if (miniBatchFraction < 1.0 && convergenceTol > 0.0) { + logWarning("Testing against a convergenceTol when using miniBatchFraction " + + "< 1.0 can be unstable because of the stochasticity in sampling.") + } val stochasticLossHistory = new ArrayBuffer[Double](numIterations) + // Record previous weight and current one to calculate solution vector difference + + var previousWeights: Option[Vector] = None + var currentWeights: Option[Vector] = None val numExamples = data.count() @@ -181,7 +215,9 @@ object GradientDescent extends Logging { var regVal = updater.compute( weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 - for (i <- 1 to numIterations) { + var converged = false // indicates whether converged based on convergenceTol + var i = 1 + while (!converged && i <= numIterations) { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) @@ -204,12 +240,21 @@ object GradientDescent extends Logging { */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam) + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), + stepSize, i, regParam) weights = update._1 regVal = update._2 + + previousWeights = currentWeights + currentWeights = Some(weights) + if (previousWeights != None && currentWeights != None) { + converged = isConverged(previousWeights.get, + currentWeights.get, convergenceTol) + } } else { logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") } + i += 1 } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( @@ -218,4 +263,32 @@ object GradientDescent extends Logging { (weights, stochasticLossHistory.toArray) } + + def runMiniBatchSGD( + data: RDD[(Double, Vector)], + gradient: Gradient, + updater: Updater, + stepSize: Double, + numIterations: Int, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Vector): (Vector, Array[Double]) = + GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations, + regParam, miniBatchFraction, initialWeights, 0.001) + + + private def isConverged( + previousWeights: Vector, + currentWeights: Vector, + convergenceTol: Double): Boolean = { + // To compare with convergence tolerance. + val previousBDV = previousWeights.toBreeze.toDenseVector + val currentBDV = currentWeights.toBreeze.toDenseVector + + // This represents the difference of updated weights in the iteration. + val solutionVecDiff: Double = norm(previousBDV - currentBDV) + + solutionVecDiff < convergenceTol * Math.max(norm(currentBDV), 1.0) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 235e043c7754b..c6d04464a12ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -85,4 +85,10 @@ class StreamingLinearRegressionWithSGD private[mllib] ( this } + /** Set the convergence tolerance. */ + def setConvergenceTol(tolerance: Double): this.type = { + this.algorithm.optimizer.setConvergenceTol(tolerance) + this + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index e8f3d0c4db20a..2473510e13514 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -196,6 +196,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w .setStepSize(10.0) .setRegParam(0.0) .setNumIterations(20) + .setConvergenceTol(0.0005) val model = lr.run(testRDD) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index a5a59e9fad5ae..13b754a03943a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -82,11 +82,11 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Vectors.dense(1.0 +: features.toArray) + label -> MLUtils.appendBias(features) } val dataRDD = sc.parallelize(data, 2).cache() - val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) + val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0) val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, @@ -139,6 +139,45 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with "The different between newWeights with/without regularization " + "should be initialWeightsWithIntercept.") } + + test("iteration should end with convergence tolerance") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val initialB = -1.0 + val initialWeights = Array(initialB) + + val gradient = new LogisticGradient() + val updater = new SimpleUpdater() + val stepSize = 1.0 + val numIterations = 10 + val regParam = 0 + val miniBatchFrac = 1.0 + val convergenceTolerance = 5.0e-1 + + // Add a extra variable consisting of all 1.0's for the intercept. + val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) + val data = testData.map { case LabeledPoint(label, features) => + label -> MLUtils.appendBias(features) + } + + val dataRDD = sc.parallelize(data, 2).cache() + val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0) + + val (_, loss) = GradientDescent.runMiniBatchSGD( + dataRDD, + gradient, + updater, + stepSize, + numIterations, + regParam, + miniBatchFrac, + initialWeightsWithIntercept, + convergenceTolerance) + + assert(loss.length < numIterations, "convergenceTolerance failed to stop optimization early") + } } class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index d07b9d5b89227..75ae0eb32fb7b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -122,7 +122,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers numGDIterations, regParam, miniBatchFrac, - initialWeightsWithIntercept) + initialWeightsWithIntercept, + convergenceTol) assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5, "The first losses of LBFGS and GD should be the same.") @@ -221,7 +222,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers numGDIterations, regParam, miniBatchFrac, - initialWeightsWithIntercept) + initialWeightsWithIntercept, + convergenceTol) // for class LBFGS and the optimize method, we only look at the weights assert( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 08a152ffc7a23..39537e7bb4c72 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -100,7 +100,7 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { val testRDD = sc.parallelize(testData, 2).cache() val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005) val model = ls.run(testRDD, initialWeights) val weight0 = model.weights(0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index f5e2d31056cbd..a2a4c5f6b8b70 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -53,6 +53,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.2) .setNumIterations(25) + .setConvergenceTol(0.0001) // generate sequence of simulated data val numBatches = 10 From fc7aebd94a3c09657fc4dbded0997ed068304e0a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Jul 2015 15:43:02 -0700 Subject: [PATCH 115/274] [SPARK-8784] [SQL] Add Python API for hex and unhex Also improve the performance of hex/unhex Author: Davies Liu Closes #7181 from davies/hex and squashes the following commits: f032fbb [Davies Liu] Merge branch 'hex' of github.com:davies/spark into hex 49e325f [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex b31fc9a [Davies Liu] Update math.scala 25156b7 [Davies Liu] address comments and fix test c3af78c [Davies Liu] address commments 1a24082 [Davies Liu] Add Python API for hex and unhex --- python/pyspark/sql/functions.py | 28 ++++ .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 142 +++++++++--------- .../expressions/MathFunctionsSuite.scala | 18 ++- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 115 insertions(+), 77 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 12263e6a75af8..8a470ce19bc30 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -381,6 +381,34 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. + + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=bytearray(b'ABC'))] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.unhex(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def sha1(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e7e4d1c4efe18..ca87bcc4c4aab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -160,7 +160,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[Upper]("ucase"), - expression[UnHex]("unhex"), + expression[Unhex]("unhex"), expression[Upper]("upper") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 035980da568d3..1e095149f1166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -227,6 +227,20 @@ case class Bin(child: Expression) } } +object Hex { + val hexDigits = Array[Char]( + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + ).map(_.toByte) + + // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 + val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } +} /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. @@ -258,30 +272,18 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { case LongType => hex(num.asInstanceOf[Long]) case IntegerType => hex(num.asInstanceOf[Integer].toLong) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String]) + case StringType => hex(num.asInstanceOf[UTF8String].getBytes) } } } - /** - * Converts every character in s to two hex digits. - */ - private def hex(str: UTF8String): UTF8String = { - hex(str.getBytes) - } - - private def hex(bytes: Array[Byte]): UTF8String = { - doHex(bytes, bytes.length) - } - - private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + private[this] def hex(bytes: Array[Byte]): UTF8String = { + val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Character.toUpperCase(Character.forDigit( - (bytes(i) & 0xF0) >>> 4, 16)).toByte - value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( - bytes(i) & 0x0F, 16)).toByte + value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) @@ -294,14 +296,64 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { var len = 0 do { len += 1 - value(value.length - len) = Character.toUpperCase(Character - .forDigit((numBuf & 0xF).toInt, 16)).toByte + value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) } } +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. + */ +case class Unhex(child: Expression) + extends UnaryExpression with ExpectsInputTypes with Serializable { + + override def nullable: Boolean = true + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].getBytes) + } + } + + private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + val out = new Array[Byte]((bytes.length + 1) >> 1) + var i = 0 + if ((bytes.length & 0x01) != 0) { + // padding with '0' + if (bytes(0) < 0) { + return null + } + val v = Hex.unhexDigits(bytes(0)) + if (v == -1) { + return null + } + out(0) = v + i += 1 + } + // two characters form the hex value. + while (i < bytes.length) { + if (bytes(i) < 0 || bytes(i + 1) < 0) { + return null + } + val first = Hex.unhexDigits(bytes(i)) + val second = Hex.unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { + return null + } + out(i / 2) = (((first << 4) | second) & 0xFF).toByte + i += 2 + } + out + } +} //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -448,58 +500,6 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress } } -/** - * Performs the inverse operation of HEX. - * Resulting characters are returned as a byte array. - */ -case class UnHex(child: Expression) extends UnaryExpression with Serializable { - - override def dataType: DataType = BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}") - } - } - - override def eval(input: InternalRow): Any = { - val num = child.eval(input) - if (num == null) { - null - } else { - unhex(num.asInstanceOf[UTF8String].getBytes) - } - } - - private val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { - var bytes = inputBytes - if ((bytes.length & 0x01) != 0) { - bytes = '0'.toByte +: bytes - } - val out = new Array[Byte](bytes.length >> 1) - // two characters form the hex value. - var i = 0 - while (i < bytes.length) { - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { return null} - out(i / 2) = (((first << 4) | second) & 0xFF).toByte - i += 2 - } - out - } -} - case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index aa27fe3cd5564..550c6e3cc9f0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType} +import org.apache.spark.sql.types._ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -252,11 +252,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hex") { + checkEvaluation(Hex(Literal.create(null, IntegerType)), null) checkEvaluation(Hex(Literal(28)), "1C") checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal.create(null, LongType)), null) checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal.create(null, StringType)), null) checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal.create(null, BinaryType)), null) checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars @@ -265,9 +269,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("unhex") { - checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) - checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) - checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) + checkEvaluation(Unhex(Literal.create(null, StringType)), null) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) + // scalastyle:on } test("hypot") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4ee1fb8374b07..4b1353fc32c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1061,7 +1061,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = UnHex(column.expr) + def unhex(column: Column): Column = Unhex(column.expr) /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number From 488bad319a70975733e83c83490240a70beb0c90 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 2 Jul 2015 15:55:16 -0700 Subject: [PATCH 116/274] [SPARK-7104] [MLLIB] Support model save/load in Python's Word2Vec Author: Yu ISHIKAWA Closes #6821 from yu-iskw/SPARK-7104 and squashes the following commits: 975136b [Yu ISHIKAWA] Organize import 0ef58b6 [Yu ISHIKAWA] Use rmtree, instead of removedirs cb21653 [Yu ISHIKAWA] Add an explicit type for `Word2VecModelWrapper.save` 1d468ef [Yu ISHIKAWA] [SPARK-7104][MLlib] Support model save/load in Python's Word2Vec --- .../mllib/api/python/PythonMLLibAPI.scala | 3 +++ python/pyspark/mllib/feature.py | 21 ++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 458fab48fef5a..e628059c4af8e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -28,6 +28,7 @@ import scala.reflect.ClassTag import net.razorvine.pickle._ +import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ @@ -641,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable { def getVectors: JMap[String, JList[Float]] = { model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) } /** diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index b5138773fd61b..f921e3ad1a314 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -36,6 +36,7 @@ from pyspark.mllib.linalg import ( Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', @@ -416,7 +417,7 @@ def fit(self, dataset): return IDFModel(jmodel) -class Word2VecModel(JavaVectorTransformer): +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model """ @@ -455,6 +456,12 @@ def getVectors(self): """ return self.call("getVectors") + @classmethod + def load(cls, sc, path): + jmodel = sc._jvm.org.apache.spark.mllib.feature \ + .Word2VecModel.load(sc._jsc.sc(), path) + return Word2VecModel(jmodel) + @ignore_unicode_prefix class Word2Vec(object): @@ -488,6 +495,18 @@ class Word2Vec(object): >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] [u'b', u'c'] + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = Word2VecModel.load(sc, path) + >>> model.transform("a") == sameModel.transform("a") + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def __init__(self): """ From e589e71a2914588985eaea799b52e2f6b4f1e9ae Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 2 Jul 2015 16:25:10 -0700 Subject: [PATCH 117/274] Revert "[SPARK-8784] [SQL] Add Python API for hex and unhex" This reverts commit fc7aebd94a3c09657fc4dbded0997ed068304e0a. --- python/pyspark/sql/functions.py | 28 ---- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 142 +++++++++--------- .../expressions/MathFunctionsSuite.scala | 18 +-- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 77 insertions(+), 115 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8a470ce19bc30..12263e6a75af8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -381,34 +381,6 @@ def randn(seed=None): return Column(jc) -@ignore_unicode_prefix -@since(1.5) -def hex(col): - """Computes hex value of the given column, which could be StringType, - BinaryType, IntegerType or LongType. - - >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() - [Row(hex(a)=u'414243', hex(b)=u'3')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.hex(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def unhex(col): - """Inverse of hex. Interprets each pair of characters as a hexadecimal number - and converts to the byte representation of number. - - >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() - [Row(unhex(a)=bytearray(b'ABC'))] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.unhex(_to_java_column(col)) - return Column(jc) - - @ignore_unicode_prefix @since(1.5) def sha1(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ca87bcc4c4aab..e7e4d1c4efe18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -160,7 +160,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[Upper]("ucase"), - expression[Unhex]("unhex"), + expression[UnHex]("unhex"), expression[Upper]("upper") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 1e095149f1166..035980da568d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -227,20 +227,6 @@ case class Bin(child: Expression) } } -object Hex { - val hexDigits = Array[Char]( - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' - ).map(_.toByte) - - // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 - val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } -} /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. @@ -272,18 +258,30 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { case LongType => hex(num.asInstanceOf[Long]) case IntegerType => hex(num.asInstanceOf[Integer].toLong) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String].getBytes) + case StringType => hex(num.asInstanceOf[UTF8String]) } } } - private[this] def hex(bytes: Array[Byte]): UTF8String = { - val length = bytes.length + /** + * Converts every character in s to two hex digits. + */ + private def hex(str: UTF8String): UTF8String = { + hex(str.getBytes) + } + + private def hex(bytes: Array[Byte]): UTF8String = { + doHex(bytes, bytes.length) + } + + private def doHex(bytes: Array[Byte], length: Int): UTF8String = { val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) - value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) + value(i * 2) = Character.toUpperCase(Character.forDigit( + (bytes(i) & 0xF0) >>> 4, 16)).toByte + value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( + bytes(i) & 0x0F, 16)).toByte i += 1 } UTF8String.fromBytes(value) @@ -296,64 +294,14 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { var len = 0 do { len += 1 - value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) + value(value.length - len) = Character.toUpperCase(Character + .forDigit((numBuf & 0xF).toInt, 16)).toByte numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) } } -/** - * Performs the inverse operation of HEX. - * Resulting characters are returned as a byte array. - */ -case class Unhex(child: Expression) - extends UnaryExpression with ExpectsInputTypes with Serializable { - - override def nullable: Boolean = true - override def dataType: DataType = BinaryType - override def inputTypes: Seq[DataType] = Seq(BinaryType) - - override def eval(input: InternalRow): Any = { - val num = child.eval(input) - if (num == null) { - null - } else { - unhex(num.asInstanceOf[UTF8String].getBytes) - } - } - - private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { - val out = new Array[Byte]((bytes.length + 1) >> 1) - var i = 0 - if ((bytes.length & 0x01) != 0) { - // padding with '0' - if (bytes(0) < 0) { - return null - } - val v = Hex.unhexDigits(bytes(0)) - if (v == -1) { - return null - } - out(0) = v - i += 1 - } - // two characters form the hex value. - while (i < bytes.length) { - if (bytes(i) < 0 || bytes(i + 1) < 0) { - return null - } - val first = Hex.unhexDigits(bytes(i)) - val second = Hex.unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { - return null - } - out(i / 2) = (((first << 4) | second) & 0xFF).toByte - i += 2 - } - out - } -} //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -500,6 +448,58 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress } } +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. + */ +case class UnHex(child: Expression) extends UnaryExpression with Serializable { + + override def dataType: DataType = BinaryType + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}") + } + } + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].getBytes) + } + } + + private val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } + + private def unhex(inputBytes: Array[Byte]): Array[Byte] = { + var bytes = inputBytes + if ((bytes.length & 0x01) != 0) { + bytes = '0'.toByte +: bytes + } + val out = new Array[Byte](bytes.length >> 1) + // two characters form the hex value. + var i = 0 + while (i < bytes.length) { + val first = unhexDigits(bytes(i)) + val second = unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { return null} + out(i / 2) = (((first << 4) | second) & 0xFF).toByte + i += 2 + } + out + } +} + case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 550c6e3cc9f0b..aa27fe3cd5564 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType} class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -252,15 +252,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hex") { - checkEvaluation(Hex(Literal.create(null, IntegerType)), null) checkEvaluation(Hex(Literal(28)), "1C") checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") - checkEvaluation(Hex(Literal.create(null, LongType)), null) checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") - checkEvaluation(Hex(Literal.create(null, StringType)), null) checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") - checkEvaluation(Hex(Literal.create(null, BinaryType)), null) checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars @@ -269,15 +265,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("unhex") { - checkEvaluation(Unhex(Literal.create(null, StringType)), null) - checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) - checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) - checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) - checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) - // scalastyle:off - // Turn off scala style for non-ascii chars - checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) - // scalastyle:on + checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) + checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) } test("hypot") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4b1353fc32c35..4ee1fb8374b07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1061,7 +1061,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = Unhex(column.expr) + def unhex(column: Column): Column = UnHex(column.expr) /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number From d9838196ff48faeac19756852a7f695129c08047 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 18:07:09 -0700 Subject: [PATCH 118/274] [SPARK-8782] [SQL] Fix code generation for ORDER BY NULL This fixes code generation for queries containing `ORDER BY NULL`. Previously, the generated code would fail to compile. Author: Josh Rosen Closes #7179 from JoshRosen/generate-order-fixes and squashes the following commits: 6ef49a6 [Josh Rosen] Fix ORDER BY NULL 0036696 [Josh Rosen] Add regression test for SPARK-8782 (ORDER BY NULL) --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 1 + .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a64027e48a00b..9f6329bbda4ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -185,6 +185,7 @@ class CodeGenContext { // use c1 - c2 may overflow case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" + case NullType => "0" case other => s"$c1.compare($c2)" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 82dc0e9ce5132..cc6af1ccc1cce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1451,4 +1451,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } + + test("SPARK-8782: ORDER BY NULL") { + withTempTable("t") { + Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") + checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) + } + } } From aa7bbc143844020e4711b3aa4ce75c1b7733a80d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 2 Jul 2015 21:38:21 -0500 Subject: [PATCH 119/274] [SPARK-6980] [CORE] Akka timeout exceptions indicate which conf controls them (RPC Layer) Latest changes after refactoring to the RPC layer. I rebased against trunk to make sure to get any recent changes since it had been a while. I wasn't crazy about the name `ConfigureTimeout` and `RpcTimeout` seemed to fit better, but I'm open to suggestions! I ran most of the tests and they pass, but others would get stuck with "WARN TaskSchedulerImpl: Initial job has not accepted any resources". I think its just my machine, so I'd though I would push what I have anyway. Still left to do: * I only added a couple unit tests so far, there are probably some more cases to test * Make sure all uses require a `RpcTimeout` * Right now, both the `ask` and `Await.result` use the same timeout, should we differentiate between these in the TimeoutException message? * I wrapped `Await.result` in `RpcTimeout`, should we also wrap `Await.ready`? * Proper scoping of classes and methods hardmettle, feel free to help out with any of these! Author: Bryan Cutler Author: Harsh Gupta Author: BryanCutler Closes #6205 from BryanCutler/configTimeout-6980 and squashes the following commits: 46c8d48 [Bryan Cutler] [SPARK-6980] Changed RpcEnvSuite test to never reply instead of just sleeping, to avoid possible sync issues 06afa53 [Bryan Cutler] [SPARK-6980] RpcTimeout class extends Serializable, was causing error in MasterSuite 7bb70f1 [Bryan Cutler] Merge branch 'master' into configTimeout-6980 dbd5f73 [Bryan Cutler] [SPARK-6980] Changed RpcUtils askRpcTimeout and lookupRpcTimeout scope to private[spark] and improved deprecation warning msg 4e89c75 [Bryan Cutler] [SPARK-6980] Missed one usage of deprecated RpcUtils.askTimeout in YarnSchedulerBackend although it is not being used, and fixed SparkConfSuite UT to not use deprecated RpcUtils functions 6a1c50d [Bryan Cutler] [SPARK-6980] Minor cleanup of test case 7f4d78e [Bryan Cutler] [SPARK-6980] Fixed scala style checks 287059a [Bryan Cutler] [SPARK-6980] Removed extra import in AkkaRpcEnvSuite 3d8b1ff [Bryan Cutler] [SPARK-6980] Cleaned up imports in AkkaRpcEnvSuite 3a168c7 [Bryan Cutler] [SPARK-6980] Rewrote Akka RpcTimeout UTs in RpcEnvSuite 7636189 [Bryan Cutler] [SPARK-6980] Fixed call to askWithReply in DAGScheduler to use RpcTimeout - this was being compiled by auto-tupling and changing the message type of BlockManagerHeartbeat be11c4e [Bryan Cutler] Merge branch 'master' into configTimeout-6980 039afed [Bryan Cutler] [SPARK-6980] Corrected import organization 218aa50 [Bryan Cutler] [SPARK-6980] Corrected issues from feedback fadaf6f [Bryan Cutler] [SPARK-6980] Put back in deprecated RpcUtils askTimeout and lookupTimout to fix MiMa errors fa6ed82 [Bryan Cutler] [SPARK-6980] Had to increase timeout on positive test case because a processor slowdown could trigger an Future TimeoutException b05d449 [Bryan Cutler] [SPARK-6980] Changed constructor to use val duration instead of getter function, changed name of string property from conf to timeoutProp for consistency c6cfd33 [Bryan Cutler] [SPARK-6980] Changed UT ask message timeout to explicitly intercept a SparkException 1394de6 [Bryan Cutler] [SPARK-6980] Moved MessagePrefix to createRpcTimeoutException directly 1517721 [Bryan Cutler] [SPARK-6980] RpcTimeout object scope should be private[spark] 2206b4d [Bryan Cutler] [SPARK-6980] Added unit test for ask then immediat awaitReply 1b9beab [Bryan Cutler] [SPARK-6980] Cleaned up import ordering 08f5afc [Bryan Cutler] [SPARK-6980] Added UT for constructing RpcTimeout with default value d3754d1 [Bryan Cutler] [SPARK-6980] Added akkaConf to prevent dead letter logging 995d196 [Bryan Cutler] [SPARK-6980] Cleaned up import ordering, comments, spacing from PR feedback 7774d56 [Bryan Cutler] [SPARK-6980] Cleaned up UT imports 4351c48 [Bryan Cutler] [SPARK-6980] Added UT for addMessageIfTimeout, cleaned up UTs 1607a5f [Bryan Cutler] [SPARK-6980] Changed addMessageIfTimeout to PartialFunction, cleanup from PR comments 2f94095 [Bryan Cutler] [SPARK-6980] Added addMessageIfTimeout for when a Future is completed with TimeoutException 235919b [Bryan Cutler] [SPARK-6980] Resolved conflicts after master merge c07d05c [Bryan Cutler] Merge branch 'master' into configTimeout-6980-tmp b7fb99f [BryanCutler] Merge pull request #2 from hardmettle/configTimeoutUpdates_6980 4be3a8d [Harsh Gupta] Modifying loop condition to find property match 0ee5642 [Harsh Gupta] Changing the loop condition to halt at the first match in the property list for RpcEnv exception catch f74064d [Harsh Gupta] Retrieving properties from property list using iterator and while loop instead of chained functions a294569 [Bryan Cutler] [SPARK-6980] Added creation of RpcTimeout with Seq of property keys 23d2f26 [Bryan Cutler] [SPARK-6980] Fixed await result not being handled by RpcTimeout 49f9f04 [Bryan Cutler] [SPARK-6980] Minor cleanup and scala style fix 5b59a44 [Bryan Cutler] [SPARK-6980] Added some RpcTimeout unit tests 78a2c0a [Bryan Cutler] [SPARK-6980] Using RpcTimeout.awaitResult for future in AppClient now 97523e0 [Bryan Cutler] [SPARK-6980] Akka ask timeout description refactored to RPC layer --- .../spark/deploy/worker/ui/WorkerWebUI.scala | 2 +- .../org/apache/spark/rpc/RpcEndpointRef.scala | 17 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 112 +++++++++++++++++- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 15 ++- .../apache/spark/scheduler/DAGScheduler.scala | 3 +- .../cluster/YarnSchedulerBackend.scala | 2 +- .../spark/storage/BlockManagerMaster.scala | 14 +-- .../org/apache/spark/util/AkkaUtils.scala | 19 ++- .../org/apache/spark/util/RpcUtils.scala | 20 +++- .../org/apache/spark/SparkConfSuite.scala | 4 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 97 ++++++++++++++- 11 files changed, 258 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index b3bb5f911dbd7..334a5b10142aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - private[ui] val timeout = RpcUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 69181edb9ad44..6ae47894598be 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -17,8 +17,7 @@ package org.apache.spark.rpc -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.Future import scala.reflect.ClassTag import org.apache.spark.util.RpcUtils @@ -32,7 +31,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) private[this] val maxRetries = RpcUtils.numRetries(conf) private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) /** * return the address for the [[RpcEndpointRef]] @@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. */ - def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to @@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts var attempts = 0 var lastException: Exception = null @@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) attempts += 1 try { val future = ask[T](message, timeout) - val result = Await.result(future, timeout) + val result = timeout.awaitResult(future) if (result == null) { throw new SparkException("Actor returned null") } @@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) lastException = e logWarning(s"Error sending message [message = $message] in $attempts attempts", e) } - Thread.sleep(retryWaitMs) + + if (attempts < maxRetries) { + Thread.sleep(retryWaitMs) + } } throw new SparkException( s"Error sending message [message = $message]", lastException) } + } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 3b6938ec639c3..1709bdf560b6f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -18,8 +18,10 @@ package org.apache.spark.rpc import java.net.URI +import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Future} +import scala.concurrent.{Awaitable, Await, Future} +import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.{SecurityManager, SparkConf} @@ -66,7 +68,7 @@ private[spark] object RpcEnv { */ private[spark] abstract class RpcEnv(conf: SparkConf) { - private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) + private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -94,7 +96,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. */ def setupEndpointRefByURI(uri: String): RpcEndpointRef = { - Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } /** @@ -184,3 +186,107 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 31ebe5ac5bca3..f2d87f68341af 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future -import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -214,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] ( override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). + // this is just in case there is a timeout from creating the future in resolveOne, we want the + // exception to indicate the conf that determines the timeout + recover(defaultLookupTimeout.addMessageIfTimeout) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -295,8 +297,8 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { // The function will run in the calling thread, so it should be short and never block. case msg @ AkkaMessage(message, reply) => if (reply) { @@ -307,7 +309,8 @@ private[akka] class AkkaRpcEndpointRef( } case AkkaFailure(e) => Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T] + }(ThreadUtils.sameThread).mapTo[T]. + recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } override def toString: String = s"${getClass.getSimpleName}($actorRef)" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a7cf0c23d9613..6841fa835747f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ @@ -188,7 +189,7 @@ class DAGScheduler( blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( - BlockManagerHeartbeat(blockManagerId), 600 seconds) + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } // Called by TaskScheduler when an executor fails. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 190ff61d689d1..bc67abb5df446 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend( private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7cdae22b0e253..f70f701494dbf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -33,7 +33,7 @@ class BlockManagerMaster( isDriver: Boolean) extends Logging { - val timeout = RpcUtils.askTimeout(conf) + val timeout = RpcUtils.askRpcTimeout(conf) /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { @@ -106,7 +106,7 @@ class BlockManagerMaster( logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -118,7 +118,7 @@ class BlockManagerMaster( logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -132,7 +132,7 @@ class BlockManagerMaster( s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -176,8 +176,8 @@ class BlockManagerMaster( CanBuildFrom[Iterable[Future[Option[BlockStatus]]], Option[BlockStatus], Iterable[Option[BlockStatus]]]] - val blockStatus = Await.result( - Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread), timeout) + val blockStatus = timeout.awaitResult( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread)) if (blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } @@ -199,7 +199,7 @@ class BlockManagerMaster( askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) - Await.result(future, timeout) + timeout.awaitResult(future) } /** diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 96aa2fe164703..c179833e5b06a 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -18,8 +18,6 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap -import scala.concurrent.Await -import scala.concurrent.duration.FiniteDuration import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -28,6 +26,7 @@ import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} +import org.apache.spark.rpc.RpcTimeout /** * Various utility classes for working with Akka. @@ -147,7 +146,7 @@ private[spark] object AkkaUtils extends Logging { def askWithReply[T]( message: Any, actor: ActorRef, - timeout: FiniteDuration): T = { + timeout: RpcTimeout): T = { askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) } @@ -160,7 +159,7 @@ private[spark] object AkkaUtils extends Logging { actor: ActorRef, maxAttempts: Int, retryInterval: Long, - timeout: FiniteDuration): T = { + timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts if (actor == null) { throw new SparkException(s"Error sending message [message = $message]" + @@ -171,8 +170,8 @@ private[spark] object AkkaUtils extends Logging { while (attempts < maxAttempts) { attempts += 1 try { - val future = actor.ask(message)(timeout) - val result = Await.result(future, timeout) + val future = actor.ask(message)(timeout.duration) + val result = timeout.awaitResult(future) if (result == null) { throw new SparkException("Actor returned null") } @@ -198,9 +197,9 @@ private[spark] object AkkaUtils extends Logging { val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = RpcUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupRpcTimeout(conf) logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) } def makeExecutorRef( @@ -212,9 +211,9 @@ private[spark] object AkkaUtils extends Logging { val executorActorSystemName = SparkEnv.executorActorSystemName Utils.checkHost(host, "Expected hostname") val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = RpcUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupRpcTimeout(conf) logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) } def protocol(actorSystem: ActorSystem): String = { diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index f16cc8e7e42c6..7578a3b1d85f2 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,11 +17,11 @@ package org.apache.spark.util -import scala.concurrent.duration._ +import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} object RpcUtils { @@ -47,14 +47,22 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC ask operations. */ + private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = { + RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") + } + + @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0") def askTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.askTimeout", - conf.get("spark.network.timeout", "120s")) seconds + askRpcTimeout(conf).duration } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ + private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { + RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") + } + + @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0") def lookupTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.lookupTimeout", - conf.get("spark.network.timeout", "120s")) seconds + lookupRpcTimeout(conf).duration } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 9fbaeb33f97cd..90cb7da94e88a 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -260,10 +260,10 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst assert(RpcUtils.retryWaitMs(conf) === 2L) conf.set("spark.akka.askTimeout", "3") - assert(RpcUtils.askTimeout(conf) === (3 seconds)) + assert(RpcUtils.askRpcTimeout(conf).duration === (3 seconds)) conf.set("spark.akka.lookupTimeout", "4") - assert(RpcUtils.lookupTimeout(conf) === (4 seconds)) + assert(RpcUtils.lookupRpcTimeout(conf).duration === (4 seconds)) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 1f0aa759b08da..6ceafe4337747 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -155,16 +155,21 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) val conf = new SparkConf() + val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { - val e = intercept[Exception] { - rpcEndpointRef.askWithRetry[String]("hello", 1 millis) + // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause + val e = intercept[SparkException] { + rpcEndpointRef.askWithRetry[String]("hello", new RpcTimeout(1 millis, shortProp)) } - assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + // The SparkException cause should be a RpcTimeoutException with message indicating the + // controlling timeout property + assert(e.getCause.isInstanceOf[RpcTimeoutException]) + assert(e.getCause.getMessage.contains(shortProp)) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -539,6 +544,92 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("construct RpcTimeout with conf property") { + val conf = new SparkConf + + val testProp = "spark.ask.test.timeout" + val testDurationSeconds = 30 + val secondaryProp = "spark.ask.secondary.timeout" + + conf.set(testProp, s"${testDurationSeconds}s") + conf.set(secondaryProp, "100s") + + // Construct RpcTimeout with a single property + val rt1 = RpcTimeout(conf, testProp) + assert( testDurationSeconds === rt1.duration.toSeconds ) + + // Construct RpcTimeout with prioritized list of properties + val rt2 = RpcTimeout(conf, Seq("spark.ask.invalid.timeout", testProp, secondaryProp), "1s") + assert( testDurationSeconds === rt2.duration.toSeconds ) + + // Construct RpcTimeout with default value, + val defaultProp = "spark.ask.default.timeout" + val defaultDurationSeconds = 1 + val rt3 = RpcTimeout(conf, Seq(defaultProp), defaultDurationSeconds.toString + "s") + assert( defaultDurationSeconds === rt3.duration.toSeconds ) + assert( rt3.timeoutProp.contains(defaultProp) ) + + // Try to construct RpcTimeout with an unconfigured property + intercept[NoSuchElementException] { + RpcTimeout(conf, "spark.ask.invalid.timeout") + } + } + + test("ask a message timeout on Future using RpcTimeout") { + case class NeverReply(msg: String) + + val rpcEndpointRef = env.setupEndpoint("ask-future", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.reply(msg) + case _: NeverReply => + } + }) + + val longTimeout = new RpcTimeout(1 second, "spark.rpc.long.timeout") + val shortTimeout = new RpcTimeout(10 millis, "spark.rpc.short.timeout") + + // Ask with immediate response, should complete successfully + val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout) + val reply1 = longTimeout.awaitResult(fut1) + assert("hello" === reply1) + + // Ask with a delayed response and wait for response immediately that should timeout + val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout) + val reply2 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut2) + }.getMessage + + // RpcTimeout.awaitResult should have added the property to the TimeoutException message + assert(reply2.contains(shortTimeout.timeoutProp)) + + // Ask with delayed response and allow the Future to timeout before Await.result + val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout) + + // Allow future to complete with failure using plain Await.result, this will return + // once the future is complete to verify addMessageIfTimeout was invoked + val reply3 = + intercept[RpcTimeoutException] { + Await.result(fut3, 200 millis) + }.getMessage + + // When the future timed out, the recover callback should have used + // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message + assert(reply3.contains(shortTimeout.timeoutProp)) + + // Use RpcTimeout.awaitResult to process Future, since it has already failed with + // RpcTimeoutException, the same RpcTimeoutException should be thrown + val reply4 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut3) + }.getMessage + + // Ensure description is not in message twice after addMessageIfTimeout and awaitResult + assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) + } + } class UnserializableClass From 1a7a7d7d579c5cba104daffbda977915802bf9b9 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 2 Jul 2015 20:37:31 -0700 Subject: [PATCH 120/274] [SPARK-8213][SQL]Add function factorial Author: zhichao.li Closes #6822 from zhichao-li/factorial and squashes the following commits: 26edf4f [zhichao.li] add factorial --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 80 ++++++++++++++++++- .../expressions/MathFunctionsSuite.scala | 15 +++- .../org/apache/spark/sql/functions.scala | 16 ++++ .../spark/sql/MathExpressionsSuite.scala | 13 ++- 5 files changed, 122 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e7e4d1c4efe18..9163b032adee4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -113,6 +113,7 @@ object FunctionRegistry { expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"), + expression[Factorial]("factorial"), expression[Hypot]("hypot"), expression[Hex]("hex"), expression[Logarithm]("log"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 035980da568d3..701ab9912adba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -21,8 +21,10 @@ import java.lang.{Long => JLong} import java.util.Arrays import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{StringType} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{DataType, DoubleType, LongType, IntegerType} import org.apache.spark.unsafe.types.UTF8String /** @@ -159,6 +161,82 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") +object Factorial { + + def factorial(n: Int): Long = { + if (n < factorials.length) factorials(n) else Long.MaxValue + } + + private val factorials: Array[Long] = Array[Long]( + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800L, + 87178291200L, + 1307674368000L, + 20922789888000L, + 355687428096000L, + 6402373705728000L, + 121645100408832000L, + 2432902008176640000L + ) +} + +case class Factorial(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def dataType: DataType = LongType + + override def foldable: Boolean = child.foldable + + // If the value not in the range of [0, 20], it still will be null, so set it to be true here. + override def nullable: Boolean = true + + override def toString: String = s"factorial($child)" + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val input = evalE.asInstanceOf[Integer] + if (input > 20 || input < 0) { + null + } else { + Factorial.factorial(input) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} > 20 || ${eval.primitive} < 0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = + org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive}); + } + } + """ + } +} + case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") case class Log2(child: Expression) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index aa27fe3cd5564..8457864d1782d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import com.google.common.math.LongMath + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType} +import org.apache.spark.sql.types.{DataType, LongType} +import org.apache.spark.sql.types.{IntegerType, DoubleType} class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -157,6 +160,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Floor, math.floor) } + test("factorial") { + val dataLong = (0 to 20) + dataLong.foreach { value => + checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow) + } + checkEvaluation((Literal.create(null, IntegerType)), null, create_row(null)) + checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) + checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + } + test("rint") { testUnary(Rint, math.rint) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4ee1fb8374b07..0d5d49c3dd1d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1022,6 +1022,22 @@ object functions { */ def expm1(columnName: String): Column = expm1(Column(columnName)) + /** + * Computes the factorial of the given value. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(e: Column): Column = Factorial(e.expr) + + /** + * Computes the factorial of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(columnName: String): Column = factorial(Column(columnName)) + /** * Computes the floor of the given value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 4c5696deaff81..dc8f994adbd39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} - private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) @@ -183,6 +182,18 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(floor, math.floor) } + test("factorial") { + val df = (0 to 5).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(factorial('a)), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + checkAnswer( + df.selectExpr("factorial(a)"), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + } + test("rint") { testOneToOneMathFunction(rint, math.rint) } From dfd8bac8f5b4f2b733c1ddd58e53ee0ba431e6b3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 2 Jul 2015 20:47:04 -0700 Subject: [PATCH 121/274] Minor style fix for the previous commit. --- .../spark/sql/catalyst/expressions/math.scala | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 701ab9912adba..273a6c5016577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -21,10 +21,8 @@ import java.lang.{Long => JLong} import java.util.Arrays import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.{StringType} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataType, DoubleType, LongType, IntegerType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -192,8 +190,7 @@ object Factorial { ) } -case class Factorial(child: Expression) - extends UnaryExpression with ExpectsInputTypes { +case class Factorial(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -204,8 +201,6 @@ case class Factorial(child: Expression) // If the value not in the range of [0, 20], it still will be null, so set it to be true here. override def nullable: Boolean = true - override def toString: String = s"factorial($child)" - override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { @@ -372,8 +367,8 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { var len = 0 do { len += 1 - value(value.length - len) = Character.toUpperCase(Character - .forDigit((numBuf & 0xF).toInt, 16)).toByte + value(value.length - len) = + Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) From 20a4d7dbd18fd4d1e3fb9324749453123714f99f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 2 Jul 2015 21:30:57 -0700 Subject: [PATCH 122/274] [SPARK-8501] [SQL] Avoids reading schema from empty ORC files ORC writes empty schema (`struct<>`) to ORC files containing zero rows. This is OK for Hive since the table schema is managed by the metastore. But it causes trouble when reading raw ORC files via Spark SQL since we have to discover the schema from the files. Notice that the ORC data source always avoids writing empty ORC files, but it's still problematic when reading Hive tables which contain empty part-files. Author: Cheng Lian Closes #7199 from liancheng/spark-8501 and squashes the following commits: bb8cd95 [Cheng Lian] Addresses comments a290221 [Cheng Lian] Avoids reading schema from empty ORC files --- .../spark/sql/hive/orc/OrcFileOperator.scala | 60 +++++++++++++++---- .../spark/sql/hive/orc/OrcRelation.scala | 44 ++++++++------ .../spark/sql/hive/orc/OrcQuerySuite.scala | 55 ++++++++++++++--- .../spark/sql/hive/orc/OrcSourceSuite.scala | 28 ++++----- 4 files changed, 135 insertions(+), 52 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index e3ab9442b4821..0f9a1a6ef3b27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -24,30 +24,70 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType private[orc] object OrcFileOperator extends Logging { - def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { + /** + * Retrieves a ORC file reader from a given path. The path can point to either a directory or a + * single ORC file. If it points to an directory, it picks any non-empty ORC file within that + * directory. + * + * The reader returned by this method is mainly used for two purposes: + * + * 1. Retrieving file metadata (schema and compression codecs, etc.) + * 2. Read the actual file content (in this case, the given path should point to the target file) + * + * @note As recorded by SPARK-8501, ORC writes an empty schema (struct<> + logInfo( + s"ORC file $path has empty schema, it probably contains no rows. " + + "Trying to read another ORC file to figure out the schema.") + false + case _ => true + } + } + val conf = config.getOrElse(new Configuration) - val fspath = new Path(pathStr) - val fs = fspath.getFileSystem(conf) - val orcFiles = listOrcFiles(pathStr, conf) - logDebug(s"Creating ORC Reader from ${orcFiles.head}") - // TODO Need to consider all files when schema evolution is taken into account. - OrcFile.createReader(fs, orcFiles.head) + val fs = { + val hdfsPath = new Path(basePath) + hdfsPath.getFileSystem(conf) + } + + listOrcFiles(basePath, conf).iterator.map { path => + path -> OrcFile.createReader(fs, path) + }.collectFirst { + case (path, reader) if isWithNonEmptySchema(path, reader) => reader + } } def readSchema(path: String, conf: Option[Configuration]): StructType = { - val reader = getFileReader(path, conf) + val reader = getFileReader(path, conf).getOrElse { + throw new AnalysisException( + s"Failed to discover schema from ORC files stored in $path. " + + "Probably there are either no ORC files or only empty ORC files.") + } val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName logDebug(s"Reading schema from file $path, got Hive schema string: $schema") HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } - def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = { - getFileReader(path, conf).getObjectInspector.asInstanceOf[StructObjectInspector] + def getObjectInspector( + path: String, conf: Option[Configuration]): Option[StructObjectInspector] = { + getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector]) } def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 300f83d914ea4..9dc9fbb78e01f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -242,26 +242,34 @@ private[orc] case class OrcTableScan( nonPartitionKeyAttrs: Seq[(Attribute, Int)], mutableRow: MutableRow): Iterator[InternalRow] = { val deserializer = new OrcSerde - val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) - val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { - case (attr, ordinal) => - soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal - }.unzip - val unwrappers = fieldRefs.map(unwrapperFor) - // Map each tuple to a row object - iterator.map { value => - val raw = deserializer.deserialize(value) - var i = 0 - while (i < fieldRefs.length) { - val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + val maybeStructOI = OrcFileOperator.getObjectInspector(path, Some(conf)) + + // SPARK-8501: ORC writes an empty schema ("struct<>") to an ORC file if the file contains zero + // rows, and thus couldn't give a proper ObjectInspector. In this case we just return an empty + // partition since we know that this file is empty. + maybeStructOI.map { soi => + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = fieldRefs.map(unwrapperFor) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 } - i += 1 + mutableRow: InternalRow } - mutableRow: InternalRow + }.getOrElse { + Iterator.empty } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 267d22c6b5f1e..ca131faaeef05 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -23,10 +23,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.InternalRow -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -170,7 +167,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Default compression options for writing to an ORC file") { withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file => assertResult(CompressionKind.ZLIB) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } } @@ -183,21 +180,21 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") withOrcFile(data) { file => assertResult(CompressionKind.SNAPPY) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE") withOrcFile(data) { file => assertResult(CompressionKind.NONE) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO") withOrcFile(data) { file => assertResult(CompressionKind.LZO) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } } @@ -289,4 +286,48 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { List(Row("same", "run_5", 100))) } } + + test("SPARK-8501: Avoids discovery schema from empty ORC files") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempTable("empty", "single") { + sqlContext.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '$path' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.registerTempTable("empty") + + // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. + sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val errorMessage = intercept[AnalysisException] { + sqlContext.read.format("orc").load(path) + }.getMessage + + assert(errorMessage.contains("Failed to discover schema from ORC files")) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.registerTempTable("single") + + sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = sqlContext.read.format("orc").load(path) + assert(df.schema === singleRowDF.schema.asNullable) + checkAnswer(df, singleRowDF) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index a0cdd0db42d65..82e08caf46457 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -43,14 +43,8 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { orcTableDir.mkdir() import org.apache.spark.sql.hive.test.TestHive.implicits._ - // Originally we were using a 10-row RDD for testing. However, when default parallelism is - // greater than 10 (e.g., running on a node with 32 cores), this RDD contains empty partitions, - // which result in empty ORC files. Unfortunately, ORC doesn't handle empty files properly and - // causes build failure on Jenkins, which happens to have 32 cores. Please refer to SPARK-8501 - // for more details. To workaround this issue before fixing SPARK-8501, we simply increase row - // number in this RDD to avoid empty partitions. sparkContext - .makeRDD(1 to 100) + .makeRDD(1 to 10) .map(i => OrcData(i, s"part-$i")) .toDF() .registerTempTable(s"orc_temp_table") @@ -76,35 +70,35 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { } test("create temporary orc table") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(100)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 100).map(i => Row(i, s"part-$i"))) + (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source where intField > 5"), - (6 to 100).map(i => Row(i, s"part-$i"))) + (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), - (1 to 100).map(i => Row(1, s"part-$i"))) + (1 to 10).map(i => Row(1, s"part-$i"))) } test("create temporary orc table as") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(100)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 100).map(i => Row(i, s"part-$i"))) + (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source WHERE intField > 5"), - (6 to 100).map(i => Row(i, s"part-$i"))) + (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), - (1 to 100).map(i => Row(1, s"part-$i"))) + (1 to 10).map(i => Row(1, s"part-$i"))) } test("appending insert") { @@ -112,7 +106,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 100).flatMap { i => + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => Seq.fill(2)(Row(i, s"part-$i")) }) } @@ -125,7 +119,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_as_source"), - (6 to 100).map(i => Row(i, s"part-$i"))) + (6 to 10).map(i => Row(i, s"part-$i"))) } } From a59d14f623633c7aef97991341b587c11ca42328 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 2 Jul 2015 21:45:25 -0700 Subject: [PATCH 123/274] [SPARK-8801][SQL] Support TypeCollection in ExpectsInputTypes This patch adds a new TypeCollection AbstractDataType that can be used by expressions to specify more than one expected input types. Author: Reynold Xin Closes #7202 from rxin/type-collection and squashes the following commits: c714ca1 [Reynold Xin] Fixed style. a0c0d12 [Reynold Xin] Fixed bugs and unit tests. d8b8ae7 [Reynold Xin] Added TypeCollection. --- .../catalyst/analysis/HiveTypeCoercion.scala | 47 +++++++++++++--- .../spark/sql/types/AbstractDataType.scala | 50 ++++++++++++++--- .../apache/spark/sql/types/ArrayType.scala | 6 +- .../org/apache/spark/sql/types/DataType.scala | 4 +- .../apache/spark/sql/types/DecimalType.scala | 4 ++ .../org/apache/spark/sql/types/MapType.scala | 4 ++ .../apache/spark/sql/types/StructType.scala | 8 ++- .../analysis/HiveTypeCoercionSuite.scala | 55 +++++++++++++------ 8 files changed, 140 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 0bc893224026e..6006e7bf00c13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import javax.annotation.Nullable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule @@ -713,39 +715,68 @@ object HiveTypeCoercion { case e: ExpectsInputTypes => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => - implicitCast(in, expected) + // If we cannot do the implicit cast, just use the original input. + implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) } /** - * If needed, cast the expression into the expected type. - * If the implicit cast is not allowed, return the expression itself. + * Given an expected data type, try to cast the expression and return the cast expression. + * + * If the expression already fits the input type, we simply return the expression itself. + * If the expression has an incompatible type that cannot be implicitly cast, return None. */ - def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = { + def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { val inType = e.dataType - (inType, expectedType) match { + + // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. + // We wrap immediately an Option after this. + @Nullable val ret: Expression = (inType, expectedType) match { + + // If the expected type is already a parent of the input type, no need to cast. + case _ if expectedType.isParentOf(inType) => e + // Cast null type (usually from null literals) into target types - case (NullType, target: DataType) => Cast(e, target.defaultConcreteType) + case (NullType, target) => Cast(e, target.defaultConcreteType) // Implicit cast among numeric types + // If input is decimal, and we expect a decimal type, just use the input. + case (_: DecimalType, DecimalType) => e + // If input is a numeric type but not decimal, and we expect a decimal type, + // cast the input to unlimited precision decimal. + case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => + Cast(e, DecimalType.Unlimited) + // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) + case (_: NumericType, target: NumericType) => e // Implicit cast between date time types case (DateType, TimestampType) => Cast(e, TimestampType) case (TimestampType, DateType) => Cast(e, DateType) // Implicit cast from/to string - case (StringType, NumericType) => Cast(e, DoubleType) + case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited) case (StringType, target: NumericType) => Cast(e, target) case (StringType, DateType) => Cast(e, DateType) case (StringType, TimestampType) => Cast(e, TimestampType) case (StringType, BinaryType) => Cast(e, BinaryType) case (any, StringType) if any != StringType => Cast(e, StringType) + // Type collection. + // First see if we can find our input type in the type collection. If we can, then just + // use the current expression; otherwise, find the first one we can implicitly cast. + case (_, TypeCollection(types)) => + if (types.exists(_.isParentOf(inType))) { + e + } else { + types.flatMap(implicitCast(e, _)).headOption.orNull + } + // Else, just return the same input expression - case _ => e + case _ => null } + Option(ret) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 43e2f8a46e62e..e5dc99fb625d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -28,7 +28,45 @@ import org.apache.spark.util.Utils * A non-concrete data type, reserved for internal uses. */ private[sql] abstract class AbstractDataType { + /** + * The default concrete type to use if we want to cast a null literal into this type. + */ private[sql] def defaultConcreteType: DataType + + /** + * Returns true if this data type is a parent of the `childCandidate`. + */ + private[sql] def isParentOf(childCandidate: DataType): Boolean +} + + +/** + * A collection of types that can be used to specify type constraints. The sequence also specifies + * precedence: an earlier type takes precedence over a latter type. + * + * {{{ + * TypeCollection(StringType, BinaryType) + * }}} + * + * This means that we prefer StringType over BinaryType if it is possible to cast to StringType. + */ +private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType { + require(types.nonEmpty, s"TypeCollection ($types) cannot be empty") + + private[sql] override def defaultConcreteType: DataType = types.head + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = false +} + + +private[sql] object TypeCollection { + + def apply(types: DataType*): TypeCollection = new TypeCollection(types) + + def unapply(typ: AbstractDataType): Option[Seq[DataType]] = typ match { + case typ: TypeCollection => Some(typ.types) + case _ => None + } } @@ -61,7 +99,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType extends AbstractDataType { +private[sql] object NumericType { /** * Enables matching against NumericType for expressions: * {{{ @@ -70,12 +108,10 @@ private[sql] object NumericType extends AbstractDataType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - - private[sql] override def defaultConcreteType: DataType = IntegerType } -private[sql] object IntegralType extends AbstractDataType { +private[sql] object IntegralType { /** * Enables matching against IntegralType for expressions: * {{{ @@ -84,8 +120,6 @@ private[sql] object IntegralType extends AbstractDataType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] - - private[sql] override def defaultConcreteType: DataType = IntegerType } @@ -94,7 +128,7 @@ private[sql] abstract class IntegralType extends NumericType { } -private[sql] object FractionalType extends AbstractDataType { +private[sql] object FractionalType { /** * Enables matching against FractionalType for expressions: * {{{ @@ -103,8 +137,6 @@ private[sql] object FractionalType extends AbstractDataType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] - - private[sql] override def defaultConcreteType: DataType = DoubleType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 81553e7fc91a8..8ea6cb14c360e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -26,7 +26,11 @@ object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) - override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) + private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[ArrayType] + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index c333fa70d1ef4..7d00047d08d74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -75,7 +75,9 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType - override def defaultConcreteType: DataType = this + private[sql] override def defaultConcreteType: DataType = this + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 06373a095b1b0..434fc037aad4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -86,6 +86,10 @@ object DecimalType extends AbstractDataType { private[sql] override def defaultConcreteType: DataType = Unlimited + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[DecimalType] + } + val Unlimited: DecimalType = DecimalType(None) private[sql] object Fixed { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 69c2119e23436..2b25617ec6655 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -71,6 +71,10 @@ object MapType extends AbstractDataType { private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType) + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[MapType] + } + /** * Construct a [[MapType]] object with the given key type and value type. * The `valueContainsNull` is true. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 6fedeabf23203..7e77b77e73940 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -301,7 +301,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } -object StructType { +object StructType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = new StructType + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[StructType] + } def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 498fd86a06fd9..60e727c6c7d4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -27,28 +27,47 @@ import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { test("implicit type cast") { - def shouldCast(from: DataType, to: AbstractDataType): Unit = { + def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) - assert(got.dataType === to.defaultConcreteType) + assert(got.map(_.dataType) == Option(expected), + s"Failed to cast $from to $to") } + shouldCast(NullType, NullType, NullType) + shouldCast(NullType, IntegerType, IntegerType) + shouldCast(NullType, DecimalType, DecimalType.Unlimited) + // TODO: write the entire implicit cast table out for test cases. - shouldCast(ByteType, IntegerType) - shouldCast(IntegerType, IntegerType) - shouldCast(IntegerType, LongType) - shouldCast(IntegerType, DecimalType.Unlimited) - shouldCast(LongType, IntegerType) - shouldCast(LongType, DecimalType.Unlimited) - - shouldCast(DateType, TimestampType) - shouldCast(TimestampType, DateType) - - shouldCast(StringType, IntegerType) - shouldCast(StringType, DateType) - shouldCast(StringType, TimestampType) - shouldCast(IntegerType, StringType) - shouldCast(DateType, StringType) - shouldCast(TimestampType, StringType) + shouldCast(ByteType, IntegerType, IntegerType) + shouldCast(IntegerType, IntegerType, IntegerType) + shouldCast(IntegerType, LongType, LongType) + shouldCast(IntegerType, DecimalType, DecimalType.Unlimited) + shouldCast(LongType, IntegerType, IntegerType) + shouldCast(LongType, DecimalType, DecimalType.Unlimited) + + shouldCast(DateType, TimestampType, TimestampType) + shouldCast(TimestampType, DateType, DateType) + + shouldCast(StringType, IntegerType, IntegerType) + shouldCast(StringType, DateType, DateType) + shouldCast(StringType, TimestampType, TimestampType) + shouldCast(IntegerType, StringType, StringType) + shouldCast(DateType, StringType, StringType) + shouldCast(TimestampType, StringType, StringType) + + shouldCast(StringType, BinaryType, BinaryType) + shouldCast(BinaryType, StringType, StringType) + + shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType) + + shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType) + shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType) + shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType) + + shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType) + shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType) + shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) + shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) } test("tightest common bound for types") { From f743c79abe5a2fb66be32a896ea47e858569b0c7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 2 Jul 2015 22:09:07 -0700 Subject: [PATCH 124/274] [SPARK-8776] Increase the default MaxPermSize I am increasing the perm gen size to 256m. https://issues.apache.org/jira/browse/SPARK-8776 Author: Yin Huai Closes #7196 from yhuai/SPARK-8776 and squashes the following commits: 60901b4 [Yin Huai] Fix test. d44b713 [Yin Huai] Make sparkShell and hiveConsole use 256m PermGen size. 30aaf8e [Yin Huai] Increase the default PermGen size to 256m. --- .../org/apache/spark/launcher/AbstractCommandBuilder.java | 2 +- .../apache/spark/launcher/SparkSubmitCommandBuilderSuite.java | 2 +- project/SparkBuild.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 33d65d13f0d25..5e793a5c48775 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -136,7 +136,7 @@ void addPermGenSizeOpt(List cmd) { } } - cmd.add("-XX:MaxPermSize=128m"); + cmd.add("-XX:MaxPermSize=256m"); } void addOptionString(List cmd, String options) { diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 97043a76cc612..7329ac9f7fb8c 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -194,7 +194,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception { if (isDriver) { assertEquals("-XX:MaxPermSize=256m", arg); } else { - assertEquals("-XX:MaxPermSize=128m", arg); + assertEquals("-XX:MaxPermSize=256m", arg); } } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f389bcc9ceeb..3408c6d51ed4c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -206,7 +206,7 @@ object SparkBuild extends PomBuild { fork := true, outputStrategy in run := Some (StdoutOutput), - javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"), + javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m"), sparkShell := { (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value @@ -299,7 +299,7 @@ object SQL { object Hive { lazy val settings = Seq( - javaOptions += "-XX:MaxPermSize=1g", + javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), // Multiple queries rely on the TestHive singleton. See comments there for more details. From 9b23e92c727881ff9038b4fe9643c49b96914159 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 2 Jul 2015 22:10:24 -0700 Subject: [PATCH 125/274] [SPARK-8803] handle special characters in elements in crosstab cc rxin Having back ticks or null as elements causes problems. Since elements become column names, we have to drop them from the element as back ticks are special characters. Having null throws exceptions, we could replace them with empty strings. Handling back ticks should be improved for 1.5 Author: Burak Yavuz Closes #7201 from brkyvz/weird-ct-elements and squashes the following commits: e06b840 [Burak Yavuz] fix scalastyle 93a0d3f [Burak Yavuz] added tests for NaN and Infinity 9dba6ce [Burak Yavuz] address cr1 db71dbd [Burak Yavuz] handle special characters in elements in crosstab --- .../spark/sql/DataFrameNaFunctions.scala | 2 +- .../spark/sql/DataFrameStatFunctions.scala | 3 ++ .../sql/execution/stat/StatFunctions.scala | 20 ++++++++++--- .../apache/spark/sql/DataFrameStatSuite.scala | 30 +++++++++++++++++++ 4 files changed, 50 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index b4c2daa055868..8681a56c82f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. */ private def fillCol[T](col: StructField, replacement: T): Column = { - coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index edb9ed7bba56a..587869e57f96e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The first column of each row will be the distinct values of `col1` and the column names will * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts. + * Null elements will be replaced by "null", and back ticks will be dropped from elements if they + * exist. + * * * @param col1 The name of the first column. Distinct items will make the first item of * each row. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 23ddfa9839e5e..00231d65a7d54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -110,8 +110,12 @@ private[sql] object StatFunctions extends Logging { logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " + "the pairs. Please try reducing the amount of distinct items in your columns.") } + def cleanElement(element: Any): String = { + if (element == null) "null" else element.toString + } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = + counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") @@ -121,15 +125,23 @@ private[sql] object StatFunctions extends Logging { // row.get(0) is column 1 // row.get(1) is column 2 // row.get(2) is the frequency - countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) + val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get + countsRow.setLong(columnIndex + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts - countsRow.update(0, UTF8String.fromString(col1Item.toString)) + countsRow.update(0, UTF8String.fromString(cleanElement(col1Item))) countsRow }.toSeq + // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept + // special keywords and `.`, wrap the column names in ``. + def cleanColumnName(name: String): String = { + name.replace("`", "") + } // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in // SPARK-8681. We need to explicitly sort by the column index and assign the column names. - val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType)) + val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r => + StructField(cleanColumnName(r._1.toString), LongType) + } val schema = StructType(StructField(tableName, StringType) +: headerNames) new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 765094da6bda7..7ba4ba73e0cc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -85,6 +85,36 @@ class DataFrameStatSuite extends SparkFunSuite { } } + test("special crosstab elements (., '', null, ``)") { + val data = Seq( + ("a", Double.NaN, "ho"), + (null, 2.0, "ho"), + ("a.b", Double.NegativeInfinity, ""), + ("b", Double.PositiveInfinity, "`ha`"), + ("a", 1.0, null) + ) + val df = data.toDF("1", "2", "3") + val ct1 = df.stat.crosstab("1", "2") + // column fields should be 1 + distinct elements of second column + assert(ct1.schema.fields.length === 6) + assert(ct1.collect().length === 4) + val ct2 = df.stat.crosstab("1", "3") + assert(ct2.schema.fields.length === 5) + assert(ct2.schema.fieldNames.contains("ha")) + assert(ct2.collect().length === 4) + val ct3 = df.stat.crosstab("3", "2") + assert(ct3.schema.fields.length === 6) + assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("Infinity")) + assert(ct3.schema.fieldNames.contains("-Infinity")) + assert(ct3.collect().length === 4) + val ct4 = df.stat.crosstab("3", "1") + assert(ct4.schema.fields.length === 5) + assert(ct4.schema.fieldNames.contains("null")) + assert(ct4.schema.fieldNames.contains("a.b")) + assert(ct4.collect().length === 4) + } + test("Frequent Items") { val rows = Seq.tabulate(1000) { i => if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) From 2848f4da47d5c395de93ab9960bd905edfbd3439 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 3 Jul 2015 00:25:02 -0700 Subject: [PATCH 126/274] [SPARK-8809][SQL] Remove ConvertNaNs analyzer rule. "NaN" from string to double is already handled by Cast expression itself. Author: Reynold Xin Closes #7206 from rxin/convertnans and squashes the following commits: 3d99c33 [Reynold Xin] [SPARK-8809][SQL] Remove ConvertNaNs analyzer rule. --- .../catalyst/analysis/HiveTypeCoercion.scala | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6006e7bf00c13..38eb8322c854f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -35,7 +35,6 @@ object HiveTypeCoercion { val typeCoercionRules = PropagateTypes :: - ConvertNaNs :: InConversion :: WidenTypes :: PromoteStrings :: @@ -148,38 +147,6 @@ object HiveTypeCoercion { } } - /** - * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to - * the appropriate numeric equivalent. - */ - // TODO: remove this rule and make Cast handle Nan. - object ConvertNaNs extends Rule[LogicalPlan] { - private val StringNaN = Literal("NaN") - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - - /* Double Conversions */ - case b @ BinaryOperator(StringNaN, right @ DoubleType()) => - b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryOperator(left @ DoubleType(), StringNaN) => - b.makeCopy(Array(left, Literal(Double.NaN))) - - /* Float Conversions */ - case b @ BinaryOperator(StringNaN, right @ FloatType()) => - b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryOperator(left @ FloatType(), StringNaN) => - b.makeCopy(Array(left, Literal(Float.NaN))) - - /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryOperator(left @ StringNaN, StringNaN) => - b.makeCopy(Array(left, Literal(Float.NaN))) - } - } - } - /** * Widens numeric types and converts strings to numbers when appropriate. * From ab535b9a1dab40ea7335ff9abb9b522fc2b5ed66 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Fri, 3 Jul 2015 15:39:16 -0700 Subject: [PATCH 127/274] [SPARK-8226] [SQL] Add function shiftrightunsigned Author: zhichao.li Closes #7035 from zhichao-li/shiftRightUnsigned and squashes the following commits: 6bcca5a [zhichao.li] change coding style 3e9f5ae [zhichao.li] python style d85ae0b [zhichao.li] add shiftrightunsigned --- python/pyspark/sql/functions.py | 13 +++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 49 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 13 +++++ .../org/apache/spark/sql/functions.scala | 20 ++++++++ .../spark/sql/MathExpressionsSuite.scala | 17 +++++++ 6 files changed, 113 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 12263e6a75af8..69e563ef36e87 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -436,6 +436,19 @@ def shiftRight(col, numBits): return Column(jc) +@since(1.5) +def shiftRightUnsigned(col, numBits): + """Unsigned shift the the given value numBits right. + + >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\ + .collect() + [Row(r=9223372036854775787)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9163b032adee4..cd5ba1217ccc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -129,6 +129,7 @@ object FunctionRegistry { expression[Rint]("rint"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), + expression[ShiftRightUnsigned]("shiftrightunsigned"), expression[Signum]("sign"), expression[Signum]("signum"), expression[Sin]("sin"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 273a6c5016577..0fc320fb08876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress } } +case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression { + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess + case (_, IntegerType) => left.dataType match { + case LongType | IntegerType | ShortType | ByteType => + return TypeCheckResult.TypeCheckSuccess + case _ => // failed + } + case _ => // failed + } + TypeCheckResult.TypeCheckFailure( + s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " + + s"integer value as second argument, not (${left.dataType}, ${right.dataType})") + } + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: Long => l >>> valueRight.asInstanceOf[Integer] + case i: Integer => i >>> valueRight.asInstanceOf[Integer] + case s: Short => s >>> valueRight.asInstanceOf[Integer] + case b: Byte => b >>> valueRight.asInstanceOf[Integer] + } + } else { + null + } + } else { + null + } + } + + override def dataType: DataType = { + left.dataType match { + case LongType => LongType + case IntegerType | ShortType | ByteType => IntegerType + case _ => NullType + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;") + } +} + /** * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 8457864d1782d..20839c83d4fd0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) } + test("shift right unsigned") { + checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21) + checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21) + checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21) + checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) + + checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) + } + test("hex") { checkEvaluation(Hex(Literal(28)), "1C") checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d5d49c3dd1d7..4b70dc5fdde8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1343,6 +1343,26 @@ object functions { */ def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(columnName: String, numBits: Int): Column = + shiftRightUnsigned(Column(columnName), numBits) + + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(e: Column, numBits: Int): Column = + ShiftRightUnsigned(e.expr, lit(numBits).expr) + /** * Shift the the given value numBits right. If the given value is a long value, it will return * a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index dc8f994adbd39..24bef21b999ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest { Row(21.toLong, 21, 21.toShort, 21.toByte, null)) } + test("shift right unsigned") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1), + shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)", + "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + } + test("binary log") { val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") checkAnswer( From f0fac2aa80da7c739b88043571e5d49ba40f9413 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 3 Jul 2015 15:49:32 -0700 Subject: [PATCH 128/274] [SPARK-7401] [MLLIB] [PYSPARK] Vectorize dot product and sq_dist between SparseVector and DenseVector Currently we iterate over indices which can be vectorized. Author: MechCoder Closes #5946 from MechCoder/spark-7203 and squashes the following commits: 034d086 [MechCoder] Vectorize dot calculation for numpy arrays for ndim=2 bce2b07 [MechCoder] fix doctest fcad0a3 [MechCoder] Remove type checks for list, pyarray etc 0ee5dd4 [MechCoder] Add tests and other isinstance changes e5f1de0 [MechCoder] [SPARK-7401] Vectorize dot product and sq_dist --- python/pyspark/mllib/linalg.py | 44 ++++++++++++++++------------------ python/pyspark/mllib/tests.py | 8 +++++++ 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index e96c5ef87df86..9959a01cce7e0 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -577,22 +577,19 @@ def dot(self, other): ... AssertionError: dimension mismatch """ - if type(other) == np.ndarray: - if other.ndim == 2: - results = [self.dot(other[:, i]) for i in xrange(other.shape[1])] - return np.array(results) - elif other.ndim > 2: + + if isinstance(other, np.ndarray): + if other.ndim not in [2, 1]: raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim) + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.values, other[self.indices]) assert len(self) == _vector_size(other), "dimension mismatch" - if type(other) in (np.ndarray, array.array, DenseVector): - result = 0.0 - for i in xrange(len(self.indices)): - result += self.values[i] * other[self.indices[i]] - return result + if isinstance(other, DenseVector): + return np.dot(other.array[self.indices], self.values) - elif type(other) is SparseVector: + elif isinstance(other, SparseVector): result = 0.0 i, j = 0, 0 while i < len(self.indices) and j < len(other.indices): @@ -635,22 +632,23 @@ def squared_distance(self, other): AssertionError: dimension mismatch """ assert len(self) == _vector_size(other), "dimension mismatch" - if type(other) in (list, array.array, DenseVector, np.array, np.ndarray): - if type(other) is np.array and other.ndim != 1: + + if isinstance(other, np.ndarray) or isinstance(other, DenseVector): + if isinstance(other, np.ndarray) and other.ndim != 1: raise Exception("Cannot call squared_distance with %d-dimensional array" % other.ndim) - result = 0.0 - j = 0 # index into our own array - for i in xrange(len(other)): - if j < len(self.indices) and self.indices[j] == i: - diff = self.values[j] - other[i] - result += diff * diff - j += 1 - else: - result += other[i] * other[i] + if isinstance(other, DenseVector): + other = other.array + sparse_ind = np.zeros(other.size, dtype=bool) + sparse_ind[self.indices] = True + dist = other[sparse_ind] - self.values + result = np.dot(dist, dist) + + other_ind = other[~sparse_ind] + result += np.dot(other_ind, other_ind) return result - elif type(other) is SparseVector: + elif isinstance(other, SparseVector): result = 0.0 i, j = 0, 0 while i < len(self.indices) and j < len(other.indices): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 49ce125de7e78..d9f9874d50c1a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -129,17 +129,22 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) self.assertEquals(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) self.assertEquals(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) self.assertEquals(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEquals(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) dv = DenseVector(array([1., 2., 3., 4.])) lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) self.assertEquals(15.0, _squared_distance(sv, dv)) self.assertEquals(25.0, _squared_distance(sv, lst)) self.assertEquals(20.0, _squared_distance(dv, lst)) @@ -149,6 +154,9 @@ def test_squared_distance(self): self.assertEquals(0.0, _squared_distance(sv, sv)) self.assertEquals(0.0, _squared_distance(dv, dv)) self.assertEquals(0.0, _squared_distance(lst, lst)) + self.assertEquals(25.0, _squared_distance(sv, lst1)) + self.assertEquals(3.0, _squared_distance(sv, arr)) + self.assertEquals(3.0, _squared_distance(sv, narr)) def test_conversion(self): # numpy arrays should be automatically upcast to float64 From e92c24d37cae54634e7af20cbfe313d023786f87 Mon Sep 17 00:00:00 2001 From: Spiro Michaylov Date: Fri, 3 Jul 2015 20:15:58 -0700 Subject: [PATCH 129/274] [SPARK-8810] [SQL] Added several UDF unit tests for Spark SQL One test for each of the GROUP BY, WHERE and HAVING clauses, and one that combines all three with an additional UDF in the SELECT. (Since this is my first attempt at contributing to SPARK, meta-level guidance on anything I've screwed up would be greatly appreciated, whether important or minor.) Author: Spiro Michaylov Closes #7207 from spirom/udf-test-branch and squashes the following commits: 6bbba9e [Spiro Michaylov] Responded to review comments on UDF unit tests 1a3c5ff [Spiro Michaylov] Added several UDF unit tests for Spark SQL --- .../scala/org/apache/spark/sql/UDFSuite.scala | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 703a34c47ec20..8e5da3ac14da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -82,6 +82,76 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } + test("UDF in a WHERE") { + ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + + val df = ctx.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("integerData") + + val result = + ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } + + test("UDF in a HAVING") { + ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } + + test("UDF in a GROUP BY") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } + + test("UDFs everywhere") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) + ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) + ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } + test("struct UDF") { ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) From 4a22bce8fce30f86f364467a8ba51d2e744ff379 Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Fri, 3 Jul 2015 22:14:21 -0700 Subject: [PATCH 130/274] [SPARK-8572] [SQL] Type coercion for ScalaUDFs Implemented type coercion for udf arguments in Scala. The changes include- * Add `with ExpectsInputTypes ` to `ScalaUDF` class. * Pass down argument types info from `UDFRegistration` and `functions`. With this patch, the example query in [SPARK-8572](https://issues.apache.org/jira/browse/SPARK-8572) no longer throws a type cast error at runtime. Also added a unit test to `UDFSuite` in which a decimal type is passed to a udf that expects an int. Author: Cheolsoo Park Closes #7203 from piaozhexiu/SPARK-8572 and squashes the following commits: 2d0ed15 [Cheolsoo Park] Incorporate comments dce1efd [Cheolsoo Park] Fix unit tests and update the codegen script 066deed [Cheolsoo Park] Type coercion for udf inputs --- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../apache/spark/sql/UDFRegistration.scala | 75 ++++++++++++------- .../spark/sql/UserDefinedFunction.scala | 7 +- .../org/apache/spark/sql/functions.scala | 38 +++++++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 6 ++ 6 files changed, 93 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 38eb8322c854f..84acc0e7e90ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -680,7 +680,7 @@ object HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes => + case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. implicitCast(in, expected).getOrElse(in) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index caf021b016a41..fc055c97a179f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,8 +24,11 @@ import org.apache.spark.sql.types.DataType * User-defined function. * @param dataType Return type of function. */ -case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression]) - extends Expression { +case class ScalaUDF( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes { override def nullable: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 03dc37aa73f0c..d35d37d017198 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.{List => JList, Map => JMap} import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import org.apache.spark.{Accumulator, Logging} import org.apache.spark.api.python.PythonBroadcast @@ -30,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType - /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. * @@ -87,6 +87,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { (0 to 22).map { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" /** * Register a Scala closure of ${x} arguments as user-defined function (UDF). @@ -95,7 +96,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try($inputTypes).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) }""") @@ -126,7 +128,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -138,7 +141,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -150,7 +154,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -162,7 +167,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -174,7 +180,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -186,7 +193,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -198,7 +206,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -210,7 +219,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -222,7 +232,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -234,7 +245,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -246,7 +258,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -258,7 +271,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -270,7 +284,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -282,7 +297,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -294,7 +310,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -306,7 +323,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -318,7 +336,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -330,7 +349,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -342,7 +362,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -354,7 +375,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -366,7 +388,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -378,7 +401,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -390,7 +414,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 831eb7eb0fae9..b14e00ab9b163 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -41,10 +41,13 @@ import org.apache.spark.sql.types.DataType * @since 1.3.0 */ @Experimental -case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { +case class UserDefinedFunction protected[sql] ( + f: AnyRef, + dataType: DataType, + inputTypes: Seq[DataType] = Nil) { def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr))) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4b70dc5fdde8d..d261baf920c0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection @@ -1584,6 +1585,7 @@ object functions { (0 to 10).map { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" /** * Defines a user-defined function of ${x} arguments as user-defined function (UDF). @@ -1593,7 +1595,8 @@ object functions { * @since 1.3.0 */ def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try($inputTypes).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) }""") } @@ -1625,7 +1628,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1636,7 +1640,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1647,7 +1652,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1658,7 +1664,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1669,7 +1676,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1680,7 +1688,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1691,7 +1700,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1702,7 +1712,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1713,7 +1724,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1724,7 +1736,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1735,7 +1748,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 8e5da3ac14da6..c1516b450cbd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -166,4 +166,10 @@ class UDFSuite extends QueryTest { // 1 + 1 is constant folded causing a transformation. assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } + + test("type coercion for udf inputs") { + ctx.udf.register("intExpected", (x: Int) => x) + // pass a decimal to intExpected. + assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + } } From 9fb6b832bcc2556aa9db2981106cbd09f2959031 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 3 Jul 2015 22:19:43 -0700 Subject: [PATCH 131/274] [SPARK-8192] [SPARK-8193] [SQL] udf current_date, current_timestamp Author: Daoyuan Wang Closes #6985 from adrian-wang/udfcurrent and squashes the following commits: 6a20b64 [Daoyuan Wang] remove codegen and add lazy in testsuite 27c9f95 [Daoyuan Wang] refine tests.. e11ae75 [Daoyuan Wang] refine tests 61ed3d5 [Daoyuan Wang] add in functions 98e8550 [Daoyuan Wang] fix sytle 427d9dc [Daoyuan Wang] add tests and codegen 0b69a1f [Daoyuan Wang] udf current --- .../catalyst/analysis/FunctionRegistry.scala | 6 ++- .../expressions/datetimeFunctions.scala | 52 +++++++++++++++++++ .../expressions/DatetimeFunctionsSuite.scala | 37 +++++++++++++ .../org/apache/spark/sql/functions.scala | 17 ++++++ .../spark/sql/DatetimeExpressionsSuite.scala | 48 +++++++++++++++++ 5 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5ba1217ccc0..a1299aed555c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -163,7 +163,11 @@ object FunctionRegistry { expression[Substring]("substring"), expression[Upper]("ucase"), expression[UnHex]("unhex"), - expression[Upper]("upper") + expression[Upper]("upper"), + + // datetime functions + expression[CurrentDate]("current_date"), + expression[CurrentTimestamp]("current_timestamp") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala new file mode 100644 index 0000000000000..13ba2f2e5d62d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * Returns the current date at the start of query evaluation. + * All calls of current_date within the same query return the same value. + */ +case class CurrentDate() extends LeafExpression { + override def foldable: Boolean = true + override def nullable: Boolean = false + + override def dataType: DataType = DateType + + override def eval(input: InternalRow): Any = { + DateTimeUtils.millisToDays(System.currentTimeMillis()) + } +} + +/** + * Returns the current timestamp at the start of query evaluation. + * All calls of current_timestamp within the same query return the same value. + */ +case class CurrentTimestamp() extends LeafExpression { + override def foldable: Boolean = true + override def nullable: Boolean = false + + override def dataType: DataType = TimestampType + + override def eval(input: InternalRow): Any = { + System.currentTimeMillis() * 10000L + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala new file mode 100644 index 0000000000000..1618c24871c60 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class DatetimeFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + test("datetime function current_date") { + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) + } + + test("datetime function current_timestamp") { + val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) + val t1 = System.currentTimeMillis() + assert(math.abs(t1 - ct.getTime) < 5000) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d261baf920c0c..25e37ff67aa00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.Utils * * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions + * @groupname datetime_funcs Date time functions * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions @@ -991,6 +992,22 @@ object functions { */ def cosh(columnName: String): Column = cosh(Column(columnName)) + /** + * Returns the current date. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_date(): Column = CurrentDate() + + /** + * Returns the current timestamp. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_timestamp(): Column = CurrentTimestamp() + /** * Computes the exponential of the given value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala new file mode 100644 index 0000000000000..44b915304533c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.functions._ + +class DatetimeExpressionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + import ctx.implicits._ + + lazy val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + + test("function current_date") { + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) + val d2 = DateTimeUtils.fromJavaDate( + ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) + } + + test("function current_timestamp") { + checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value + checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + Row(true)) + assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + 0).getTime - System.currentTimeMillis()) < 5000) + } + +} From f32487b7ca86f768336a7c9b173f7c610fcde86f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 3 Jul 2015 23:05:17 -0700 Subject: [PATCH 132/274] [SPARK-8777] [SQL] Add random data generator test utilities to Spark SQL This commit adds a set of random data generation utilities to Spark SQL, for use in its own unit tests. - `RandomDataGenerator.forType(DataType)` returns an `Option[() => Any]` that, if defined, contains a function for generating random values for the given DataType. The random values use the external representations for the given DataType (for example, for DateType we return `java.sql.Date` instances instead of longs). - `DateTypeTestUtilities` defines some convenience fields for looping over instances of data types. For example, `numericTypes` holds `DataType` instances for all supported numeric types. These constants will help us to raise the level of abstraction in our tests. For example, it's now very easy to write a test which is parameterized by all common data types. Author: Josh Rosen Closes #7176 from JoshRosen/sql-random-data-generators and squashes the following commits: f71634d [Josh Rosen] Roll back ScalaCheck usage e0d7d49 [Josh Rosen] Bump ScalaCheck version in LICENSE 89d86b1 [Josh Rosen] Bump ScalaCheck version. 0c20905 [Josh Rosen] Initial attempt at using ScalaCheck. b55875a [Josh Rosen] Generate doubles and floats over entire possible range. 5acdd5c [Josh Rosen] Infinity and NaN are interesting. ab76cbd [Josh Rosen] Move code to Catalyst package. d2b4a4a [Josh Rosen] Add random data generator test utilities to Spark SQL. --- .../spark/sql/RandomDataGenerator.scala | 158 ++++++++++++++++++ .../spark/sql/RandomDataGeneratorSuite.scala | 98 +++++++++++ .../spark/sql/types/DataTypeTestUtils.scala | 63 +++++++ 3 files changed, 319 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala new file mode 100644 index 0000000000000..13aad467fa578 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.Double.longBitsToDouble +import java.lang.Float.intBitsToFloat +import java.math.MathContext + +import scala.util.Random + +import org.apache.spark.sql.types._ + +/** + * Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random + * values; instead, they're biased to return "interesting" values (such as maximum / minimum values) + * with higher probability. + */ +object RandomDataGenerator { + + /** + * The conditional probability of a non-null value being drawn from a set of "interesting" values + * instead of being chosen uniformly at random. + */ + private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.5f + + /** + * The probability of the generated value being null + */ + private val PROBABILITY_OF_NULL: Float = 0.1f + + private val MAX_STR_LEN: Int = 1024 + private val MAX_ARR_SIZE: Int = 128 + private val MAX_MAP_SIZE: Int = 128 + + /** + * Helper function for constructing a biased random number generator which returns "interesting" + * values with a higher probability. + */ + private def randomNumeric[T]( + rand: Random, + uniformRand: Random => T, + interestingValues: Seq[T]): Some[() => T] = { + val f = () => { + if (rand.nextFloat() <= PROBABILITY_OF_INTERESTING_VALUE) { + interestingValues(rand.nextInt(interestingValues.length)) + } else { + uniformRand(rand) + } + } + Some(f) + } + + /** + * Returns a function which generates random values for the given [[DataType]], or `None` if no + * random data generator is defined for that data type. The generated values will use an external + * representation of the data type; for example, the random generator for [[DateType]] will return + * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a + * [[org.apache.spark.Row]]. + * + * @param dataType the type to generate values for + * @param nullable whether null values should be generated + * @param seed an optional seed for the random number generator + * @return a function which can be called to generate random values. + */ + def forType( + dataType: DataType, + nullable: Boolean = true, + seed: Option[Long] = None): Option[() => Any] = { + val rand = new Random() + seed.foreach(rand.setSeed) + + val valueGenerator: Option[() => Any] = dataType match { + case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN))) + case BinaryType => Some(() => { + val arr = new Array[Byte](rand.nextInt(MAX_STR_LEN)) + rand.nextBytes(arr) + arr + }) + case BooleanType => Some(() => rand.nextBoolean()) + case DateType => Some(() => new java.sql.Date(rand.nextInt())) + case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) + case DecimalType.Unlimited => Some( + () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED)) + case DoubleType => randomNumeric[Double]( + rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, + Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) + case FloatType => randomNumeric[Float]( + rand, r => intBitsToFloat(r.nextInt()), Seq(Float.MinValue, Float.MinPositiveValue, + Float.MaxValue, Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f)) + case ByteType => randomNumeric[Byte]( + rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte)) + case IntegerType => randomNumeric[Int]( + rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0)) + case LongType => randomNumeric[Long]( + rand, _.nextLong(), Seq(Long.MinValue, Long.MaxValue, 0L)) + case ShortType => randomNumeric[Short]( + rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) + case NullType => Some(() => null) + case ArrayType(elementType, containsNull) => { + forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { + elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) + } + } + case MapType(keyType, valueType, valueContainsNull) => { + for ( + keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong())); + valueGenerator <- + forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong())) + ) yield { + () => { + Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap + } + } + } + case StructType(fields) => { + val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => + forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong())) + } + if (maybeFieldGenerators.forall(_.isDefined)) { + val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) + Some(() => Row.fromSeq(fieldGenerators.map(_.apply()))) + } else { + None + } + } + case unsupportedType => None + } + // Handle nullability by wrapping the non-null value generator: + valueGenerator.map { valueGenerator => + if (nullable) { + () => { + if (rand.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + valueGenerator() + } + } + } else { + valueGenerator + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala new file mode 100644 index 0000000000000..dbba93dba668e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.types._ + +/** + * Tests of [[RandomDataGenerator]]. + */ +class RandomDataGeneratorSuite extends SparkFunSuite { + + /** + * Tests random data generation for the given type by using it to generate random values then + * converting those values into their Catalyst equivalents using CatalystTypeConverters. + */ + def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) + val generator = RandomDataGenerator.forType(dataType, nullable).getOrElse { + fail(s"Random data generator was not defined for $dataType") + } + if (nullable) { + assert(Iterator.fill(100)(generator()).contains(null)) + } else { + assert(Iterator.fill(100)(generator()).forall(_ != null)) + } + for (_ <- 1 to 10) { + val generatedValue = generator() + toCatalyst(generatedValue) + } + } + + // Basic types: + for ( + dataType <- DataTypeTestUtils.atomicTypes; + nullable <- Seq(true, false) + if !dataType.isInstanceOf[DecimalType] || + dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty + ) { + test(s"$dataType (nullable=$nullable)") { + testRandomDataGeneration(dataType) + } + } + + for ( + arrayType <- DataTypeTestUtils.atomicArrayTypes + if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined + ) { + test(s"$arrayType") { + testRandomDataGeneration(arrayType) + } + } + + val atomicTypesWithDataGenerators = + DataTypeTestUtils.atomicTypes.filter(RandomDataGenerator.forType(_).isDefined) + + // Complex types: + for ( + keyType <- atomicTypesWithDataGenerators; + valueType <- atomicTypesWithDataGenerators + // Scala's BigDecimal.hashCode can lead to OutOfMemoryError on Scala 2.10 (see SI-6173) and + // Spark can hit NumberFormatException errors when converting certain BigDecimals (SPARK-8802). + // For these reasons, we don't support generation of maps with decimal keys. + if !keyType.isInstanceOf[DecimalType] + ) { + val mapType = MapType(keyType, valueType) + test(s"$mapType") { + testRandomDataGeneration(mapType) + } + } + + for ( + colOneType <- atomicTypesWithDataGenerators; + colTwoType <- atomicTypesWithDataGenerators + ) { + val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) + test(s"$structType") { + testRandomDataGeneration(structType) + } + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala new file mode 100644 index 0000000000000..32632b5d6e342 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +/** + * Utility functions for working with DataTypes in tests. + */ +object DataTypeTestUtils { + + /** + * Instances of all [[IntegralType]]s. + */ + val integralType: Set[IntegralType] = Set( + ByteType, ShortType, IntegerType, LongType + ) + + /** + * Instances of all [[FractionalType]]s, including both fixed- and unlimited-precision + * decimal types. + */ + val fractionalTypes: Set[FractionalType] = Set( + DecimalType(precisionInfo = None), + DecimalType(2, 1), + DoubleType, + FloatType + ) + + /** + * Instances of all [[NumericType]]s. + */ + val numericTypes: Set[NumericType] = integralType ++ fractionalTypes + + /** + * Instances of all [[AtomicType]]s. + */ + val atomicTypes: Set[DataType] = numericTypes ++ Set( + BinaryType, + BooleanType, + DateType, + StringType, + TimestampType + ) + + /** + * Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null. + */ + val atomicArrayTypes: Set[ArrayType] = atomicTypes.map(ArrayType(_, containsNull = true)) +} From f35b0c3436898f22860d2c6c1d12f3a661005201 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 3 Jul 2015 23:45:21 -0700 Subject: [PATCH 133/274] [SPARK-8238][SPARK-8239][SPARK-8242][SPARK-8243][SPARK-8268][SQL]Add ascii/base64/unbase64/encode/decode functions Add `ascii`,`base64`,`unbase64`,`encode` and `decode` expressions. Author: Cheng Hao Closes #6843 from chenghao-intel/str_funcs2 and squashes the following commits: 78dee7d [Cheng Hao] base 64 -> base64 9d6f9f4 [Cheng Hao] remove the toString method for expressions ed5c19c [Cheng Hao] update code as comments 96170fc [Cheng Hao] scalastyle issues e2df768 [Cheng Hao] remove the unused import 491ce7b [Cheng Hao] add ascii/base64/unbase64/encode/decode functions --- .../catalyst/analysis/FunctionRegistry.scala | 5 + .../expressions/stringOperations.scala | 117 ++++++++++++++++++ .../expressions/StringFunctionsSuite.scala | 60 ++++++++- .../org/apache/spark/sql/functions.scala | 93 ++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 38 ++++++ 5 files changed, 308 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a1299aed555c1..e249b58927cc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -156,11 +156,16 @@ object FunctionRegistry { expression[Sum]("sum"), // string functions + expression[Ascii]("ascii"), + expression[Base64]("base64"), + expression[Encode]("encode"), + expression[Decode]("decode"), expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[UnBase64]("unbase64"), expression[Upper]("ucase"), expression[UnHex]("unhex"), expression[Upper]("upper"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 57918b32f8a47..154ac3508c0c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -298,3 +298,120 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI override def prettyName: String = "length" } + +/** + * Returns the numeric value of the first character of str. + */ +case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: InternalRow): Any = { + val string = child.eval(input) + if (string == null) { + null + } else { + val bytes = string.asInstanceOf[UTF8String].getBytes + if (bytes.length > 0) { + bytes(0).asInstanceOf[Int] + } else { + 0 + } + } + } +} + +/** + * Converts the argument from binary to a base 64 string. + */ +case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val bytes = child.eval(input) + if (bytes == null) { + null + } else { + UTF8String.fromBytes( + org.apache.commons.codec.binary.Base64.encodeBase64( + bytes.asInstanceOf[Array[Byte]])) + } + } +} + +/** + * Converts the argument from a base 64 string to BINARY. + */ +case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: InternalRow): Any = { + val string = child.eval(input) + if (string == null) { + null + } else { + org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) + } + } +} + +/** + * Decodes the first argument into a String using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. (As of Hive 0.12.0.). + */ +case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes { + override def children: Seq[Expression] = bin :: charset :: Nil + override def foldable: Boolean = bin.foldable && charset.foldable + override def nullable: Boolean = bin.nullable || charset.nullable + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType) + + override def eval(input: InternalRow): Any = { + val l = bin.eval(input) + if (l == null) { + null + } else { + val r = charset.eval(input) + if (r == null) { + null + } else { + val fromCharset = r.asInstanceOf[UTF8String].toString + UTF8String.fromString(new String(l.asInstanceOf[Array[Byte]], fromCharset)) + } + } + } +} + +/** + * Encodes the first argument into a BINARY using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. (As of Hive 0.12.0.) +*/ +case class Encode(value: Expression, charset: Expression) + extends Expression with ExpectsInputTypes { + override def children: Seq[Expression] = value :: charset :: Nil + override def foldable: Boolean = value.foldable && charset.foldable + override def nullable: Boolean = value.nullable || charset.nullable + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def eval(input: InternalRow): Any = { + val l = value.eval(input) + if (l == null) { + null + } else { + val r = charset.eval(input) + if (r == null) { + null + } else { + val toCharset = r.asInstanceOf[UTF8String].toString + l.asInstanceOf[UTF8String].toString.getBytes(toCharset) + } + } + } +} + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 5dbb1d562c1d9..468df20442d38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -217,11 +217,61 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("length for string") { - val regEx = 'a.string.at(0) + val a = 'a.string.at(0) checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) - checkEvaluation(StringLength(regEx), 5, create_row("abdef")) - checkEvaluation(StringLength(regEx), 0, create_row("")) - checkEvaluation(StringLength(regEx), null, create_row(null)) + checkEvaluation(StringLength(a), 5, create_row("abdef")) + checkEvaluation(StringLength(a), 0, create_row("")) + checkEvaluation(StringLength(a), null, create_row(null)) checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) } + + test("ascii for string") { + val a = 'a.string.at(0) + checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) + checkEvaluation(Ascii(a), 97, create_row("abdef")) + checkEvaluation(Ascii(a), 0, create_row("")) + checkEvaluation(Ascii(a), null, create_row(null)) + checkEvaluation(Ascii(Literal.create(null, StringType)), null, create_row("abdef")) + } + + test("base64/unbase64 for string") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 4) + + checkEvaluation(Base64(Literal(bytes)), "AQIDBA==", create_row("abdef")) + checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef")) + checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef")) + checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, create_row("abdef")) + checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA==")) + + checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes)) + checkEvaluation(Base64(b), "", create_row(Array[Byte]())) + checkEvaluation(Base64(b), null, create_row(null)) + checkEvaluation(Base64(Literal.create(null, StringType)), null, create_row("abdef")) + + checkEvaluation(UnBase64(a), null, create_row(null)) + checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef")) + } + + test("encode/decode for string") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + checkEvaluation( + Decode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界") + checkEvaluation( + Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界")) + checkEvaluation( + Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row("")) + // scalastyle:on + checkEvaluation(Encode(a, Literal("utf-8")), null, create_row(null)) + checkEvaluation(Encode(Literal.create(null, StringType), Literal("utf-8")), null) + checkEvaluation(Encode(a, Literal.create(null, StringType)), null, create_row("")) + + checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null)) + checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null) + checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 25e37ff67aa00..b63c6ee8aba4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1581,6 +1581,7 @@ object functions { /** * Computes the length of a given string value + * * @group string_funcs * @since 1.5.0 */ @@ -1588,11 +1589,103 @@ object functions { /** * Computes the length of a given string column + * * @group string_funcs * @since 1.5.0 */ def strlen(columnName: String): Column = strlen(Column(columnName)) + /** + * Computes the numeric value of the first character of the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ascii(e: Column): Column = Ascii(e.expr) + + /** + * Computes the numeric value of the first character of the specified string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ascii(columnName: String): Column = ascii(Column(columnName)) + + /** + * Computes the specified value from binary to a base64 string. + * + * @group string_funcs + * @since 1.5.0 + */ + def base64(e: Column): Column = Base64(e.expr) + + /** + * Computes the specified column from binary to a base64 string. + * + * @group string_funcs + * @since 1.5.0 + */ + def base64(columnName: String): Column = base64(Column(columnName)) + + /** + * Computes the specified value from a base64 string to binary. + * + * @group string_funcs + * @since 1.5.0 + */ + def unbase64(e: Column): Column = UnBase64(e.expr) + + /** + * Computes the specified column from a base64 string to binary. + * + * @group string_funcs + * @since 1.5.0 + */ + def unbase64(columnName: String): Column = unbase64(Column(columnName)) + + /** + * Computes the first argument into a binary from a string using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr) + + /** + * Computes the first argument into a binary from a string using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def encode(columnName: String, charsetColumnName: String): Column = + encode(Column(columnName), Column(charsetColumnName)) + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr) + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def decode(columnName: String, charsetColumnName: String): Column = + decode(Column(columnName), Column(charsetColumnName)) + + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0d43aca877f68..bd9fa400e5b34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -225,4 +225,42 @@ class DataFrameFunctionsSuite extends QueryTest { Row(l) }) } + + test("string ascii function") { + val df = Seq(("abc", "")).toDF("a", "b") + checkAnswer( + df.select(ascii($"a"), ascii("b")), + Row(97, 0)) + + checkAnswer( + df.selectExpr("ascii(a)", "ascii(b)"), + Row(97, 0)) + } + + test("string base64/unbase64 function") { + val bytes = Array[Byte](1, 2, 3, 4) + val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") + checkAnswer( + df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")), + Row("AQIDBA==", "AQIDBA==", bytes, bytes)) + + checkAnswer( + df.selectExpr("base64(a)", "unbase64(b)"), + Row("AQIDBA==", bytes)) + } + + test("string encode/decode function") { + val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") + checkAnswer( + df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")), + Row(bytes, bytes, "大千世界", "大千世界")) + + checkAnswer( + df.selectExpr("encode(a, b)", "decode(c, b)"), + Row(bytes, "大千世界")) + // scalastyle:on + } } From 6b3574e68704d58ba41efe0ea4fe928cc166afcd Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Sat, 4 Jul 2015 01:10:52 -0700 Subject: [PATCH 134/274] [SPARK-8270][SQL] levenshtein distance Jira: https://issues.apache.org/jira/browse/SPARK-8270 Info: I can not build the latest master, it stucks during the build process: `[INFO] Dependency-reduced POM written at: /Users/tarek/test/spark/bagel/dependency-reduced-pom.xml` Author: Tarek Auel Closes #7214 from tarekauel/SPARK-8270 and squashes the following commits: ab348b9 [Tarek Auel] Merge branch 'master' into SPARK-8270 a2ad318 [Tarek Auel] [SPARK-8270] changed order of fields d91b12c [Tarek Auel] [SPARK-8270] python fix adbd075 [Tarek Auel] [SPARK-8270] fixed typo 23185c9 [Tarek Auel] [SPARK-8270] levenshtein distance --- python/pyspark/sql/functions.py | 14 ++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringOperations.scala | 32 +++++++++++++++++++ .../expressions/StringFunctionsSuite.scala | 9 ++++++ .../org/apache/spark/sql/functions.scala | 23 ++++++++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 6 ++++ 6 files changed, 81 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 69e563ef36e87..49dd0332afe74 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -323,6 +323,20 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def levenshtein(left, right): + """Computes the Levenshtein distance of the two given strings. + + >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) + >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + [Row(d=3)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def md5(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e249b58927cc4..92a50e7092317 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -163,6 +163,7 @@ object FunctionRegistry { expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), + expression[Levenshtein]("levenshtein"), expression[Substring]("substr"), expression[Substring]("substring"), expression[UnBase64]("unbase64"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 154ac3508c0c5..6de40629ff27e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern +import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -299,6 +300,37 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI override def prettyName: String = "length" } +/** + * A function that return the Levenshtein distance between the two given strings. + */ +case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression + with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override def dataType: DataType = IntegerType + + override def eval(input: InternalRow): Any = { + val leftValue = left.eval(input) + if (leftValue == null) { + null + } else { + val rightValue = right.eval(input) + if(rightValue == null) { + null + } else { + StringUtils.getLevenshteinDistance(leftValue.toString, rightValue.toString) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val stringUtils = classOf[StringUtils].getName + nullSafeCodeGen(ctx, ev, (res, left, right) => + s"$res = $stringUtils.getLevenshteinDistance($left.toString(), $right.toString());") + } +} + /** * Returns the numeric value of the first character of str. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 468df20442d38..1efbe1a245e83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -274,4 +274,13 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null) checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null)) } + + test("Levenshtein distance") { + checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null) + checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null) + checkEvaluation(Levenshtein(Literal(""), Literal("")), 0) + checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0) + checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3) + checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b63c6ee8aba4b..e4109da08e0a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1580,21 +1580,36 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value - * + * Computes the length of a given string value. + * * @group string_funcs * @since 1.5.0 */ def strlen(e: Column): Column = StringLength(e.expr) /** - * Computes the length of a given string column - * + * Computes the length of a given string column. + * * @group string_funcs * @since 1.5.0 */ def strlen(columnName: String): Column = strlen(Column(columnName)) + /** + * Computes the Levenshtein distance of the two given strings. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + + /** + * Computes the Levenshtein distance of the two given strings. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(leftColumnName: String, rightColumnName: String): Column = + levenshtein(Column(leftColumnName), Column(rightColumnName)) + /** * Computes the numeric value of the first character of the specified string value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bd9fa400e5b34..bc455a922d154 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -226,6 +226,12 @@ class DataFrameFunctionsSuite extends QueryTest { }) } + test("Levenshtein distance") { + val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") + checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) + checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) + } + test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( From 48f7aed686afde70a6f0802c6cb37b0cad0509f1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 4 Jul 2015 01:11:35 -0700 Subject: [PATCH 135/274] Fixed minor style issue with the previous merge. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e4109da08e0a4..abcfc0b65020c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1581,7 +1581,7 @@ object functions { /** * Computes the length of a given string value. - * + * * @group string_funcs * @since 1.5.0 */ @@ -1589,7 +1589,7 @@ object functions { /** * Computes the length of a given string column. - * + * * @group string_funcs * @since 1.5.0 */ From 347cab85cd924ffd326f3d1367b3b156ee08052d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 4 Jul 2015 11:55:04 -0700 Subject: [PATCH 136/274] [SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType. Author: Reynold Xin Closes #7221 from rxin/implicit-cast-tests and squashes the following commits: 64b13bd [Reynold Xin] Fixed a bug .. 489b732 [Reynold Xin] [SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 ++--- .../spark/sql/types/AbstractDataType.scala | 7 ++++++ .../apache/spark/sql/types/ArrayType.scala | 2 ++ .../apache/spark/sql/types/DecimalType.scala | 2 ++ .../org/apache/spark/sql/types/MapType.scala | 2 ++ .../apache/spark/sql/types/StructType.scala | 2 ++ .../analysis/HiveTypeCoercionSuite.scala | 25 ++++++++++++++++++- 7 files changed, 42 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 583338da57117..476ac2b7cb474 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -40,7 +40,7 @@ trait CheckAnalysis { def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { exprs.flatMap(_.collect { case e: Generator => true - }).length >= 1 + }).nonEmpty } def checkAnalysis(plan: LogicalPlan): Unit = { @@ -85,12 +85,12 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK - case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty => + case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK + case e if groupingExprs.exists(_.semanticEquals(e)) => // OK case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index e5dc99fb625d8..ffefb0e7837e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -37,6 +37,9 @@ private[sql] abstract class AbstractDataType { * Returns true if this data type is a parent of the `childCandidate`. */ private[sql] def isParentOf(childCandidate: DataType): Boolean + + /** Readable string representation for the type. */ + private[sql] def simpleString: String } @@ -56,6 +59,10 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst private[sql] override def defaultConcreteType: DataType = types.head private[sql] override def isParentOf(childCandidate: DataType): Boolean = false + + private[sql] override def simpleString: String = { + types.map(_.simpleString).mkString("(", " or ", ")") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 8ea6cb14c360e..43413ec761e6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -31,6 +31,8 @@ object ArrayType extends AbstractDataType { private[sql] override def isParentOf(childCandidate: DataType): Boolean = { childCandidate.isInstanceOf[ArrayType] } + + private[sql] override def simpleString: String = "array" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 434fc037aad4f..127b16ff85bed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -90,6 +90,8 @@ object DecimalType extends AbstractDataType { childCandidate.isInstanceOf[DecimalType] } + private[sql] override def simpleString: String = "decimal" + val Unlimited: DecimalType = DecimalType(None) private[sql] object Fixed { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 2b25617ec6655..868dea13d971e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -75,6 +75,8 @@ object MapType extends AbstractDataType { childCandidate.isInstanceOf[MapType] } + private[sql] override def simpleString: String = "map" + /** * Construct a [[MapType]] object with the given key type and value type. * The `valueContainsNull` is true. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 7e77b77e73940..3b17566d54d9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -309,6 +309,8 @@ object StructType extends AbstractDataType { childCandidate.isInstanceOf[StructType] } + private[sql] override def simpleString: String = "struct" + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 60e727c6c7d4d..67d05ab536b7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { - test("implicit type cast") { + test("eligible implicit type cast") { def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) assert(got.map(_.dataType) == Option(expected), @@ -68,6 +68,29 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType) shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) + + shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType) + shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType) + } + + test("ineligible implicit type cast") { + def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { + val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got") + } + + shouldNotCast(IntegerType, DateType) + shouldNotCast(IntegerType, TimestampType) + shouldNotCast(LongType, DateType) + shouldNotCast(LongType, TimestampType) + shouldNotCast(DecimalType.Unlimited, DateType) + shouldNotCast(DecimalType.Unlimited, TimestampType) + + shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) + + shouldNotCast(IntegerType, ArrayType) + shouldNotCast(IntegerType, MapType) + shouldNotCast(IntegerType, StructType) } test("tightest common bound for types") { From c991ef5abbb501933b2a68eea1987cf8d88794a5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 4 Jul 2015 11:55:20 -0700 Subject: [PATCH 137/274] [SPARK-8822][SQL] clean up type checking in math.scala. Author: Reynold Xin Closes #7220 from rxin/SPARK-8822 and squashes the following commits: 0cda076 [Reynold Xin] Test cases. 22d0463 [Reynold Xin] Fixed type precedence. beb2a97 [Reynold Xin] [SPARK-8822][SQL] clean up type checking in math.scala. --- .../spark/sql/catalyst/expressions/math.scala | 260 +++++++----------- .../expressions/MathFunctionsSuite.scala | 31 ++- 2 files changed, 123 insertions(+), 168 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 0fc320fb08876..45b7e4d3405c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.lang.{Long => JLong} -import java.util.Arrays +import java.{lang => jl} -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -206,7 +204,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu if (evalE == null) { null } else { - val input = evalE.asInstanceOf[Integer] + val input = evalE.asInstanceOf[jl.Integer] if (input > 20 || input < 0) { null } else { @@ -290,7 +288,7 @@ case class Bin(child: Expression) if (evalE == null) { null } else { - UTF8String.fromString(JLong.toBinaryString(evalE.asInstanceOf[Long])) + UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long])) } } @@ -300,27 +298,18 @@ case class Bin(child: Expression) } } - /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) extends UnaryExpression with Serializable { +case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes { + // TODO: Create code-gen version. - override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, StringType, BinaryType)) - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[StringType] - || child.dataType.isInstanceOf[IntegerType] - || child.dataType.isInstanceOf[LongType] - || child.dataType.isInstanceOf[BinaryType] - || child.dataType == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type") - } - } + override def dataType: DataType = StringType override def eval(input: InternalRow): Any = { val num = child.eval(input) @@ -329,7 +318,6 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { } else { child.dataType match { case LongType => hex(num.asInstanceOf[Long]) - case IntegerType => hex(num.asInstanceOf[Integer].toLong) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) case StringType => hex(num.asInstanceOf[UTF8String]) } @@ -371,7 +359,55 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte numBuf >>>= 4 } while (numBuf != 0) - UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) + UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) + } +} + + +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. + */ +case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes { + // TODO: Create code-gen version. + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def dataType: DataType = BinaryType + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].getBytes) + } + } + + private val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } + + private def unhex(inputBytes: Array[Byte]): Array[Byte] = { + var bytes = inputBytes + if ((bytes.length & 0x01) != 0) { + bytes = '0'.toByte +: bytes + } + val out = new Array[Byte](bytes.length >> 1) + // two characters form the hex value. + var i = 0 + while (i < bytes.length) { + val first = unhexDigits(bytes(i)) + val second = unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { return null} + out(i / 2) = (((first << 4) | second) & 0xFF).toByte + i += 2 + } + out } } @@ -423,22 +459,19 @@ case class Pow(left: Expression, right: Expression) } } -case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression { - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess - case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => - return TypeCheckResult.TypeCheckSuccess - case _ => // failed - } - case _ => // failed - } - TypeCheckResult.TypeCheckFailure( - s"ShiftLeft expects long, integer, short or byte value as first argument and an " + - s"integer value as second argument, not (${left.dataType}, ${right.dataType})") - } +/** + * Bitwise unsigned left shift. + * @param left the base number to shift. + * @param right number of bits to left shift. + */ +case class ShiftLeft(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType override def eval(input: InternalRow): Any = { val valueLeft = left.eval(input) @@ -446,10 +479,8 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi val valueRight = right.eval(input) if (valueRight != null) { valueLeft match { - case l: Long => l << valueRight.asInstanceOf[Integer] - case i: Integer => i << valueRight.asInstanceOf[Integer] - case s: Short => s << valueRight.asInstanceOf[Integer] - case b: Byte => b << valueRight.asInstanceOf[Integer] + case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer] } } else { null @@ -459,35 +490,24 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi } } - override def dataType: DataType = { - left.dataType match { - case LongType => LongType - case IntegerType | ShortType | ByteType => IntegerType - case _ => NullType - } - } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;") } } -case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression { - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess - case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => - return TypeCheckResult.TypeCheckSuccess - case _ => // failed - } - case _ => // failed - } - TypeCheckResult.TypeCheckFailure( - s"ShiftRight expects long, integer, short or byte value as first argument and an " + - s"integer value as second argument, not (${left.dataType}, ${right.dataType})") - } +/** + * Bitwise unsigned left shift. + * @param left the base number to shift. + * @param right number of bits to left shift. + */ +case class ShiftRight(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType override def eval(input: InternalRow): Any = { val valueLeft = left.eval(input) @@ -495,10 +515,8 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress val valueRight = right.eval(input) if (valueRight != null) { valueLeft match { - case l: Long => l >> valueRight.asInstanceOf[Integer] - case i: Integer => i >> valueRight.asInstanceOf[Integer] - case s: Short => s >> valueRight.asInstanceOf[Integer] - case b: Byte => b >> valueRight.asInstanceOf[Integer] + case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer] } } else { null @@ -508,35 +526,24 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress } } - override def dataType: DataType = { - left.dataType match { - case LongType => LongType - case IntegerType | ShortType | ByteType => IntegerType - case _ => NullType - } - } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;") } } -case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression { - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess - case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => - return TypeCheckResult.TypeCheckSuccess - case _ => // failed - } - case _ => // failed - } - TypeCheckResult.TypeCheckFailure( - s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " + - s"integer value as second argument, not (${left.dataType}, ${right.dataType})") - } +/** + * Bitwise unsigned right shift, for integer and long data type. + * @param left the base number. + * @param right the number of bits to right shift. + */ +case class ShiftRightUnsigned(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType override def eval(input: InternalRow): Any = { val valueLeft = left.eval(input) @@ -544,10 +551,8 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar val valueRight = right.eval(input) if (valueRight != null) { valueLeft match { - case l: Long => l >>> valueRight.asInstanceOf[Integer] - case i: Integer => i >>> valueRight.asInstanceOf[Integer] - case s: Short => s >>> valueRight.asInstanceOf[Integer] - case b: Byte => b >>> valueRight.asInstanceOf[Integer] + case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer] } } else { null @@ -557,74 +562,21 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar } } - override def dataType: DataType = { - left.dataType match { - case LongType => LongType - case IntegerType | ShortType | ByteType => IntegerType - case _ => NullType - } - } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;") } } -/** - * Performs the inverse operation of HEX. - * Resulting characters are returned as a byte array. - */ -case class UnHex(child: Expression) extends UnaryExpression with Serializable { - - override def dataType: DataType = BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}") - } - } - - override def eval(input: InternalRow): Any = { - val num = child.eval(input) - if (num == null) { - null - } else { - unhex(num.asInstanceOf[UTF8String].getBytes) - } - } - - private val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { - var bytes = inputBytes - if ((bytes.length & 0x01) != 0) { - bytes = '0'.toByte +: bytes - } - val out = new Array[Byte](bytes.length >> 1) - // two characters form the hex value. - var i = 0 - while (i < bytes.length) { - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { return null} - out(i / 2) = (((first << 4) | second) & 0xFF).toByte - i += 2 - } - out - } -} case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +/** + * Computes the logarithm of a number. + * @param left the logarithm base, default to e. + * @param right the number to compute the logarithm of. + */ case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { @@ -642,7 +594,7 @@ case class Logarithm(left: Expression, right: Expression) defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") } logCode + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { + if (Double.isNaN(${ev.primitive})) { ${ev.isNull} = true; } """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 20839c83d4fd0..03d8400cf356b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -161,11 +161,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("factorial") { - val dataLong = (0 to 20) - dataLong.foreach { value => + (0 to 20).foreach { value => checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow) } - checkEvaluation((Literal.create(null, IntegerType)), null, create_row(null)) + checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) } @@ -244,10 +243,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42) - checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42) - checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42) - checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) + checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) } @@ -257,10 +254,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21) - checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21) - checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21) - checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) } @@ -270,16 +265,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21) - checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21) - checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21) - checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) } test("hex") { - checkEvaluation(Hex(Literal(28)), "1C") - checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") @@ -313,6 +304,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) } + + // null input should yield null output checkEvaluation( Logarithm(Literal.create(null, DoubleType), Literal(1.0)), null, @@ -321,5 +314,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) + + // negative input should yield null output + checkEvaluation( + Logarithm(Literal(-1.0), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal(-1.0)), + null, + create_row(null)) } } From 2b820f2a4bf9b154762e7516a5b0485322799da9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 4 Jul 2015 22:52:50 -0700 Subject: [PATCH 138/274] [MINOR] [SQL] Minor fix for CatalystSchemaConverter ping liancheng Author: Liang-Chi Hsieh Closes #7224 from viirya/few_fix_catalystschema and squashes the following commits: d994330 [Liang-Chi Hsieh] Minor fix for CatalystSchemaConverter. --- .../main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../spark/sql/parquet/CatalystSchemaConverter.scala | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2c258b6ee399c..6005d35f015a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -273,7 +273,7 @@ private[spark] object SQLConf { val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( key = "spark.sql.parquet.followParquetFormatSpec", defaultValue = Some(false), - doc = "Wether to stick to Parquet format specification when converting Parquet schema to " + + doc = "Whether to stick to Parquet format specification when converting Parquet schema to " + "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " + "to compatible mode if set to false.", isPublic = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 2be7c64612cd2..4ab274ec17a02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -142,7 +142,7 @@ private[parquet] class CatalystSchemaConverter( DecimalType(precision, scale) } - field.getPrimitiveTypeName match { + typeName match { case BOOLEAN => BooleanType case FLOAT => FloatType @@ -150,7 +150,7 @@ private[parquet] class CatalystSchemaConverter( case DOUBLE => DoubleType case INT32 => - field.getOriginalType match { + originalType match { case INT_8 => ByteType case INT_16 => ShortType case INT_32 | null => IntegerType @@ -161,7 +161,7 @@ private[parquet] class CatalystSchemaConverter( } case INT64 => - field.getOriginalType match { + originalType match { case INT_64 | null => LongType case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) case TIMESTAMP_MILLIS => typeNotImplemented() @@ -176,7 +176,7 @@ private[parquet] class CatalystSchemaConverter( TimestampType case BINARY => - field.getOriginalType match { + originalType match { case UTF8 | ENUM => StringType case null if assumeBinaryIsString => StringType case null => BinaryType @@ -185,7 +185,7 @@ private[parquet] class CatalystSchemaConverter( } case FIXED_LEN_BYTE_ARRAY => - field.getOriginalType match { + originalType match { case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) case INTERVAL => typeNotImplemented() case _ => illegalType() @@ -261,7 +261,7 @@ private[parquet] class CatalystSchemaConverter( // Here we implement Parquet LIST backwards-compatibility rules. // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules // scalastyle:on - private def isElementType(repeatedType: Type, parentName: String) = { + private def isElementType(repeatedType: Type, parentName: String): Boolean = { { // For legacy 2-level list types with primitive element type, e.g.: // From f9c448dce8139e85ac564daa0f7e0325e778cffe Mon Sep 17 00:00:00 2001 From: Joshi Date: Sun, 5 Jul 2015 12:58:03 -0700 Subject: [PATCH 139/274] [SPARK-7137] [ML] Update SchemaUtils checkInputColumn to print more info if needed Author: Joshi Author: Rekha Joshi Closes #5992 from rekhajoshm/fix/SPARK-7137 and squashes the following commits: 8c42b57 [Joshi] update checkInputColumn to print more info if needed 33ddd2e [Joshi] update checkInputColumn to print more info if needed acf3e17 [Joshi] update checkInputColumn to print more info if needed 8993c0e [Joshi] SPARK-7137: Add checkInputColumn back to Params and print more info e3677c9 [Rekha Joshi] Merge pull request #1 from apache/master --- .../scala/org/apache/spark/ml/util/SchemaUtils.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 7cd53c6d7ef79..76f651488aef9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -32,10 +32,15 @@ private[spark] object SchemaUtils { * @param colName column name * @param dataType required column data type */ - def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + def checkColumnType( + schema: StructType, + colName: String, + dataType: DataType, + msg: String = ""): Unit = { val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.") + s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } /** From a0cb111b22cb093e86b0daeecb3dcc41d095df40 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Sun, 5 Jul 2015 20:50:02 -0700 Subject: [PATCH 140/274] [SPARK-8549] [SPARKR] Fix the line length of SparkR [[SPARK-8549] Fix the line length of SparkR - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8549) Author: Yu ISHIKAWA Closes #7204 from yu-iskw/SPARK-8549 and squashes the following commits: 6fb131a [Yu ISHIKAWA] Fix the typo 1737598 [Yu ISHIKAWA] [SPARK-8549][SparkR] Fix the line length of SparkR --- R/pkg/R/generics.R | 3 ++- R/pkg/R/pairRDD.R | 12 ++++++------ R/pkg/R/sparkR.R | 9 ++++++--- R/pkg/R/utils.R | 31 +++++++++++++++++------------- R/pkg/inst/tests/test_includeJAR.R | 4 ++-- R/pkg/inst/tests/test_rdd.R | 12 ++++++++---- R/pkg/inst/tests/test_sparkSQL.R | 11 +++++++++-- 7 files changed, 51 insertions(+), 31 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 79055b7f18558..fad9d71158c51 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -20,7 +20,8 @@ # @rdname aggregateRDD # @seealso reduce # @export -setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) # @rdname cache-methods # @export diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7f902ba8e683e..0f1179e0aa51a 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -560,8 +560,8 @@ setMethod("join", # Left outer join two RDDs # # @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -597,8 +597,8 @@ setMethod("leftOuterJoin", # Right outer join two RDDs # # @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -634,8 +634,8 @@ setMethod("rightOuterJoin", # Full outer join two RDDs # # @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 86233e01db365..048eb8ed541e4 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -105,7 +105,8 @@ sparkR.init <- function( sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { - cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } @@ -180,14 +181,16 @@ sparkR.init <- function( sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { - sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } for (varname in names(sparkExecutorEnv)) { sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs # Seconds resolution is good enough for this purpose, so use ints diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 13cec0f712fb4..ea629a64f7158 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -334,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "MEMORY_ONLY_SER_2", "OFF_HEAP")) { match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" storageLevel <- switch(newLevel, - "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), - "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), - "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), - "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), - "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), - "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), - "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), - "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), - "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), - "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), - "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } # Utility function for functions where an argument needs to be integer but we want to allow @@ -545,9 +548,11 @@ mergePartitions <- function(rdd, zip) { lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. if (zip && lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) } if (lengthOfKeys > 1) { diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 844d86f3cc97f..cc1faeabffe30 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -18,8 +18,8 @@ context("include an external JAR in SparkContext") runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") - jarPath <- paste("--jars", - shQuote(file.path(sparkHome, "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar"))) + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index fc3c01d837de4..b79692873cec3 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("sortByKey() on pairwise RDDs", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0e4235ea8b4b3..b0ea38854304e 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -391,7 +391,7 @@ test_that("collect() and take() on a DataFrame return the same number of rows an expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) -test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { +test_that("multiple pipeline transformations result in an RDD with the correct values", { df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -756,7 +756,14 @@ test_that("toJSON() returns an RDD of the correct values", { test_that("showDF()", { df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) - expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) }) test_that("isLocal()", { From 6d0411b4f3a202cfb53f638ee5fd49072b42d3a6 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 5 Jul 2015 21:50:52 -0700 Subject: [PATCH 141/274] [SQL][Minor] Update the DataFrame API for encode/decode This is a the follow up of #6843. Author: Cheng Hao Closes #7230 from chenghao-intel/str_funcs2_followup and squashes the following commits: 52cc553 [Cheng Hao] update the code as comment --- .../expressions/stringOperations.scala | 21 ++++++++++--------- .../org/apache/spark/sql/functions.scala | 14 +++++++------ .../spark/sql/DataFrameFunctionsSuite.scala | 8 +++++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 6de40629ff27e..1a14a7a449342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -392,12 +392,13 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput /** * Decodes the first argument into a String using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. (As of Hive 0.12.0.). + * If either argument is null, the result will also be null. */ -case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes { - override def children: Seq[Expression] = bin :: charset :: Nil - override def foldable: Boolean = bin.foldable && charset.foldable - override def nullable: Boolean = bin.nullable || charset.nullable +case class Decode(bin: Expression, charset: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = bin + override def right: Expression = charset override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType) @@ -420,13 +421,13 @@ case class Decode(bin: Expression, charset: Expression) extends Expression with /** * Encodes the first argument into a BINARY using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. (As of Hive 0.12.0.) + * If either argument is null, the result will also be null. */ case class Encode(value: Expression, charset: Expression) - extends Expression with ExpectsInputTypes { - override def children: Seq[Expression] = value :: charset :: Nil - override def foldable: Boolean = value.foldable && charset.foldable - override def nullable: Boolean = value.nullable || charset.nullable + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = value + override def right: Expression = charset override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType, StringType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index abcfc0b65020c..f80291776f335 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1666,18 +1666,19 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr) + def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) /** * Computes the first argument into a binary from a string using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. + * NOTE: charset represents the string value of the character set, not the column name. * * @group string_funcs * @since 1.5.0 */ - def encode(columnName: String, charsetColumnName: String): Column = - encode(Column(columnName), Column(charsetColumnName)) + def encode(columnName: String, charset: String): Column = + encode(Column(columnName), charset) /** * Computes the first argument into a string from a binary using the provided character set @@ -1687,18 +1688,19 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr) + def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) /** * Computes the first argument into a string from a binary using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. + * NOTE: charset represents the string value of the character set, not the column name. * * @group string_funcs * @since 1.5.0 */ - def decode(columnName: String, charsetColumnName: String): Column = - decode(Column(columnName), Column(charsetColumnName)) + def decode(columnName: String, charset: String): Column = + decode(Column(columnName), charset) ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bc455a922d154..afba28515e032 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -261,11 +261,15 @@ class DataFrameFunctionsSuite extends QueryTest { // non ascii characters are not allowed in the code, so we disable the scalastyle here. val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") checkAnswer( - df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")), + df.select( + encode($"a", "utf-8"), + encode("a", "utf-8"), + decode($"c", "utf-8"), + decode("c", "utf-8")), Row(bytes, bytes, "大千世界", "大千世界")) checkAnswer( - df.selectExpr("encode(a, b)", "decode(c, b)"), + df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), Row(bytes, "大千世界")) // scalastyle:on } From 86768b7b3b0c2964e744bc491bc20a1d3140ce93 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 5 Jul 2015 23:54:25 -0700 Subject: [PATCH 142/274] [SPARK-8831][SQL] Support AbstractDataType in TypeCollection. Otherwise it is impossible to declare an expression supporting DecimalType. Author: Reynold Xin Closes #7232 from rxin/typecollection-adt and squashes the following commits: 934d3d1 [Reynold Xin] [SPARK-8831][SQL] Support AbstractDataType in TypeCollection. --- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 2 -- .../org/apache/spark/sql/types/AbstractDataType.scala | 10 ++++++---- .../sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 6 ++++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 84acc0e7e90ec..5367b7f3308ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -708,8 +708,6 @@ object HiveTypeCoercion { case (NullType, target) => Cast(e, target.defaultConcreteType) // Implicit cast among numeric types - // If input is decimal, and we expect a decimal type, just use the input. - case (_: DecimalType, DecimalType) => e // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to unlimited precision decimal. case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index ffefb0e7837e9..fb1b47e946214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -53,10 +53,12 @@ private[sql] abstract class AbstractDataType { * * This means that we prefer StringType over BinaryType if it is possible to cast to StringType. */ -private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType { +private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) + extends AbstractDataType { + require(types.nonEmpty, s"TypeCollection ($types) cannot be empty") - private[sql] override def defaultConcreteType: DataType = types.head + private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType private[sql] override def isParentOf(childCandidate: DataType): Boolean = false @@ -68,9 +70,9 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst private[sql] object TypeCollection { - def apply(types: DataType*): TypeCollection = new TypeCollection(types) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) - def unapply(typ: AbstractDataType): Option[Seq[DataType]] = typ match { + def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { case typ: TypeCollection => Some(typ.types) case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 67d05ab536b7f..b56426617789e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -71,6 +71,12 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType) shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType) + + shouldCast( + DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited) + shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) + shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) + shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) } test("ineligible implicit type cast") { From 39e4e7e4d89077a637c4cad3a986e0e3447d1ae7 Mon Sep 17 00:00:00 2001 From: Steve Lindemann Date: Mon, 6 Jul 2015 10:17:05 -0700 Subject: [PATCH 143/274] [SPARK-8841] [SQL] Fix partition pruning percentage log message When pruning partitions for a query plan, a message is logged indicating what how many partitions were selected based on predicate criteria, and what percent were pruned. The current release erroneously uses `1 - total/selected` to compute this quantity, leading to nonsense messages like "pruned -1000% partitions". The fix is simple and obvious. Author: Steve Lindemann Closes #7227 from srlindemann/master and squashes the following commits: c788061 [Steve Lindemann] fix percentPruned log message --- .../scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index ce16e050c56ed..66f7ba90140b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -65,7 +65,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { logInfo { val total = t.partitionSpec.partitions.length val selected = selectedPartitions.length - val percentPruned = (1 - total.toDouble / selected.toDouble) * 100 + val percentPruned = (1 - selected.toDouble / total.toDouble) * 100 s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." } From 293225e0cd9318ad368dde30ac6a17725d33ebb6 Mon Sep 17 00:00:00 2001 From: "Daniel Emaasit (PhD Student)" Date: Mon, 6 Jul 2015 10:36:02 -0700 Subject: [PATCH 144/274] [SPARK-8124] [SPARKR] Created more examples on SparkR DataFrames Here are more examples on SparkR DataFrames including creating a Spark Contect and a SQL context, loading data and simple data manipulation. Author: Daniel Emaasit (PhD Student) Closes #6668 from Emaasit/dan-dev and squashes the following commits: 3a97867 [Daniel Emaasit (PhD Student)] Used fewer rows for createDataFrame f7227f9 [Daniel Emaasit (PhD Student)] Using command line arguments a550f70 [Daniel Emaasit (PhD Student)] Used base R functions 33f9882 [Daniel Emaasit (PhD Student)] Renamed file b6603e3 [Daniel Emaasit (PhD Student)] changed "Describe" function to "describe" 90565dd [Daniel Emaasit (PhD Student)] Deleted the getting-started file b95a103 [Daniel Emaasit (PhD Student)] Deleted this file cc55cd8 [Daniel Emaasit (PhD Student)] combined all the code into one .R file c6933af [Daniel Emaasit (PhD Student)] changed variable name to SQLContext 8e0fe14 [Daniel Emaasit (PhD Student)] provided two options for creating DataFrames 2653573 [Daniel Emaasit (PhD Student)] Updates to a comment and variable name 275b787 [Daniel Emaasit (PhD Student)] Added the Apache License at the top of the file 2e8f724 [Daniel Emaasit (PhD Student)] Added the Apache License at the top of the file 486f44e [Daniel Emaasit (PhD Student)] Added the Apache License at the file d705112 [Daniel Emaasit (PhD Student)] Created more examples on SparkR DataFrames --- examples/src/main/r/data-manipulation.R | 107 ++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 examples/src/main/r/data-manipulation.R diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R new file mode 100644 index 0000000000000..aa2336e300a91 --- /dev/null +++ b/examples/src/main/r/data-manipulation.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns. + +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() From 0e194645f42be0d6ac9b5a712f8fc1798418736d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Jul 2015 13:26:46 -0700 Subject: [PATCH 145/274] [SPARK-8837][SPARK-7114][SQL] support using keyword in column name Author: Wenchen Fan Closes #7237 from cloud-fan/parser and squashes the following commits: e7b49bb [Wenchen Fan] support using keyword in column name --- .../apache/spark/sql/catalyst/SqlParser.scala | 28 ++++++++++++------- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 ++++++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 8d02fbf4f92c4..e8e9b9802e94b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -287,15 +287,18 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { throw new AnalysisException(s"invalid function approximate($floatLit) $udfName") } } - | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ - (ELSE ~> expression).? <~ END ^^ { - case casePart ~ altPart ~ elsePart => - val branches = altPart.flatMap { case whenExpr ~ thenExpr => - Seq(whenExpr, thenExpr) - } ++ elsePart - casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) - } - ) + | CASE ~> whenThenElse ^^ CaseWhen + | CASE ~> expression ~ whenThenElse ^^ + { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } + ) + + protected lazy val whenThenElse: Parser[List[Expression]] = + rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { + case altPart ~ elsePart => + altPart.flatMap { case whenExpr ~ thenExpr => + Seq(whenExpr, thenExpr) + } ++ elsePart + } protected lazy val cast: Parser[Expression] = CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { @@ -354,6 +357,11 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val signedPrimary: Parser[Expression] = sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e} + protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", { + case lexical.Identifier(str) => str + case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str + }) + protected lazy val primary: PackratParser[Expression] = ( literal | expression ~ ("[" ~> expression <~ "]") ^^ @@ -364,9 +372,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot + | attributeName ^^ UnresolvedAttribute.quoted ) protected lazy val dotExpressionHeader: Parser[Expression] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cc6af1ccc1cce..12ad019e8b473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1458,4 +1458,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) } } + + test("SPARK-8837: use keyword in column name") { + withTempTable("t") { + val df = Seq(1 -> "a").toDF("count", "sort") + checkAnswer(df.filter("count > 0"), Row(1, "a")) + df.registerTempTable("t") + checkAnswer(sql("select count, sort from t"), Row(1, "a")) + } + } } From 57c72fcce75907c08a1ae53a0d85447176fc3c69 Mon Sep 17 00:00:00 2001 From: Dirceu Semighini Filho Date: Mon, 6 Jul 2015 13:28:07 -0700 Subject: [PATCH 146/274] Small update in the readme file Just change the attribute from -PsparkR to -Psparkr Author: Dirceu Semighini Filho Closes #7242 from dirceusemighini/patch-1 and squashes the following commits: fad5991 [Dirceu Semighini Filho] Small update in the readme file --- R/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/README.md b/R/README.md index d7d65b4f0eca5..005f56da1670c 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ``` build/mvn -DskipTests -Psparkr package ``` From 37e4d92142a6309e2df7d36883e0c7892c3d792d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 6 Jul 2015 13:31:31 -0700 Subject: [PATCH 147/274] [SPARK-8784] [SQL] Add Python API for hex and unhex Add Python API for hex/unhex, also cleanup Hex/Unhex Author: Davies Liu Closes #7223 from davies/hex and squashes the following commits: 6f1249d [Davies Liu] no explicit rule to cast string into binary 711a6ed [Davies Liu] fix test f9fe5a3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex f032fbb [Davies Liu] Merge branch 'hex' of github.com:davies/spark into hex 49e325f [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex b31fc9a [Davies Liu] Update math.scala 25156b7 [Davies Liu] address comments and fix test c3af78c [Davies Liu] address commments 1a24082 [Davies Liu] Add Python API for hex and unhex --- python/pyspark/sql/functions.py | 28 +++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 83 ++++++++++--------- .../expressions/MathFunctionsSuite.scala | 25 ++++-- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 93 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 49dd0332afe74..dca39fa833435 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -395,6 +395,34 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. + + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=bytearray(b'ABC'))] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.unhex(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def sha1(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 92a50e7092317..fef276353022c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,7 +168,7 @@ object FunctionRegistry { expression[Substring]("substring"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), - expression[UnHex]("unhex"), + expression[Unhex]("unhex"), expression[Upper]("upper"), // datetime functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 45b7e4d3405c8..92500453980f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -298,6 +298,21 @@ case class Bin(child: Expression) } } +object Hex { + val hexDigits = Array[Char]( + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + ).map(_.toByte) + + // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 + val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } +} + /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. * Otherwise if the number is a STRING, it converts each character into its hex representation @@ -307,7 +322,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, StringType, BinaryType)) + Seq(TypeCollection(LongType, BinaryType, StringType)) override def dataType: DataType = StringType @@ -319,30 +334,18 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes child.dataType match { case LongType => hex(num.asInstanceOf[Long]) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String]) + case StringType => hex(num.asInstanceOf[UTF8String].getBytes) } } } - /** - * Converts every character in s to two hex digits. - */ - private def hex(str: UTF8String): UTF8String = { - hex(str.getBytes) - } - - private def hex(bytes: Array[Byte]): UTF8String = { - doHex(bytes, bytes.length) - } - - private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + private[this] def hex(bytes: Array[Byte]): UTF8String = { + val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Character.toUpperCase(Character.forDigit( - (bytes(i) & 0xF0) >>> 4, 16)).toByte - value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( - bytes(i) & 0x0F, 16)).toByte + value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) @@ -355,24 +358,23 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes var len = 0 do { len += 1 - value(value.length - len) = - Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte + value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) } } - /** * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def nullable: Boolean = true override def dataType: DataType = BinaryType override def eval(input: InternalRow): Any = { @@ -384,26 +386,31 @@ case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTyp } } - private val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { - var bytes = inputBytes + private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + val out = new Array[Byte]((bytes.length + 1) >> 1) + var i = 0 if ((bytes.length & 0x01) != 0) { - bytes = '0'.toByte +: bytes + // padding with '0' + if (bytes(0) < 0) { + return null + } + val v = Hex.unhexDigits(bytes(0)) + if (v == -1) { + return null + } + out(0) = v + i += 1 } - val out = new Array[Byte](bytes.length >> 1) // two characters form the hex value. - var i = 0 while (i < bytes.length) { - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { return null} + if (bytes(i) < 0 || bytes(i + 1) < 0) { + return null + } + val first = Hex.unhexDigits(bytes(i)) + val second = Hex.unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { + return null + } out(i / 2) = (((first << 4) | second) & 0xFF).toByte i += 2 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 03d8400cf356b..7ca9e30b2bcd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -21,8 +21,7 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DataType, LongType} -import org.apache.spark.sql.types.{IntegerType, DoubleType} +import org.apache.spark.sql.types._ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -271,20 +270,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hex") { + checkEvaluation(Hex(Literal.create(null, LongType)), null) + checkEvaluation(Hex(Literal(28L)), "1C") + checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4") checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") - checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal.create(null, BinaryType)), null) checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars - checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84") + checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on } test("unhex") { - checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) - checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) - checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) + checkEvaluation(Unhex(Literal.create(null, StringType)), null) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + checkEvaluation(Unhex(Literal("GG")), null) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) + checkEvaluation(Unhex(Literal("三重的")), null) + + // scalastyle:on } test("hypot") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f80291776f335..4da9ffc495e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1095,7 +1095,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = UnHex(column.expr) + def unhex(column: Column): Column = Unhex(column.expr) /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number From 2471c0bf7f463bb144b44a2e51c0f363e71e099d Mon Sep 17 00:00:00 2001 From: kai Date: Mon, 6 Jul 2015 14:33:30 -0700 Subject: [PATCH 148/274] [SPARK-4485] [SQL] 1) Add broadcast hash outer join, (2) Fix SparkPlanTest This pull request (1) extracts common functions used by hash outer joins and put it in interface HashOuterJoin (2) adds ShuffledHashOuterJoin and BroadcastHashOuterJoin (3) adds test cases for shuffled and broadcast hash outer join (3) makes SparkPlanTest to support binary or more complex operators, and fixes bugs in plan composition in SparkPlanTest Author: kai Closes #7162 from kai-zeng/outer and squashes the following commits: 3742359 [kai] Fix not-serializable exception for code-generated keys in broadcasted relations 14e4bf8 [kai] Use CanBroadcast in broadcast outer join planning dc5127e [kai] code style fixes b5a4efa [kai] (1) Add broadcast hash outer join, (2) Fix SparkPlanTest --- .../spark/sql/execution/SparkStrategies.scala | 12 +- .../joins/BroadcastHashOuterJoin.scala | 121 ++++++++++++++++++ .../sql/execution/joins/HashOuterJoin.scala | 95 ++++---------- .../joins/ShuffledHashOuterJoin.scala | 85 ++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 40 +++++- .../spark/sql/execution/SparkPlanTest.scala | 99 +++++++++++--- .../sql/execution/joins/OuterJoinSuite.scala | 88 +++++++++++++ 7 files changed, 441 insertions(+), 99 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5daf86d817586..32044989044a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -117,8 +117,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys( + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys( + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.HashOuterJoin( + joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala new file mode 100644 index 0000000000000..5da04c78744d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.ThreadUtils + +import scala.collection.JavaConversions._ +import scala.concurrent._ +import scala.concurrent.duration._ + +/** + * :: DeveloperApi :: + * Performs a outer hash join for two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + val timeout = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + override def requiredChildDistribution: Seq[Distribution] = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + private[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + private[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + @transient + private val broadcastFuture = future { + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() + // buildHashTable uses code-generated rows as keys, which are not serializable + val hashed = + buildHashTable(input.iterator, new InterpretedProjection(buildKeys, buildPlan.output)) + sparkContext.broadcast(hashed) + }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + + override def doExecute(): RDD[InternalRow] = { + val broadcastRelation = Await.result(broadcastFuture, timeout) + + streamedPlan.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + + joinType match { + case LeftOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + } + } +} + +object BroadcastHashOuterJoin { + + private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index e41538ec1fc1a..886b5fa0c5103 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,32 +19,25 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} -import org.apache.spark.rdd.RDD - -import scala.collection.JavaConversions._ - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ @DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { +trait HashOuterJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val joinType: JoinType + val condition: Option[Expression] + val left: SparkPlan + val right: SparkPlan + +override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -52,9 +45,6 @@ case class HashOuterJoin( throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -68,8 +58,8 @@ case class HashOuterJoin( } } - @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow] + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @@ -80,7 +70,7 @@ case class HashOuterJoin( // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. - private[this] def leftOuterIterator( + protected[this] def leftOuterIterator( key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { @@ -89,7 +79,7 @@ case class HashOuterJoin( val temp = rightIter.collect { case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -101,18 +91,17 @@ case class HashOuterJoin( ret.iterator } - private[this] def rightOuterIterator( + protected[this] def rightOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = leftIter.collect { case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy + joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -124,10 +113,9 @@ case class HashOuterJoin( ret.iterator } - private[this] def fullOuterIterator( + protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -171,7 +159,7 @@ case class HashOuterJoin( } } - private[this] def buildHashTable( + protected[this] def buildHashTable( iter: Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() @@ -190,43 +178,4 @@ case class HashOuterJoin( hashTable } - - protected override def doExecute(): RDD[InternalRow] = { - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - - joinType match { - case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) - leftIter.flatMap( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) - }) - - case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) - rightIter.flatMap ( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) - }) - - case FullOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) - } - - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala new file mode 100644 index 0000000000000..cfc9c14aaa363 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +import scala.collection.JavaConversions._ + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class ShuffledHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val joinedRow = new JoinedRow() + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + joinType match { + case LeftOuter => + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val keyGenerator = newProjection(leftKeys, left.output) + leftIter.flatMap( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val keyGenerator = newProjection(rightKeys, right.output) + rightIter.flatMap ( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case FullOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST), + joinedRow) + } + + case x => + throw new IllegalArgumentException( + s"ShuffledHashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 20390a5544304..8953889d1fae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -45,9 +45,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j - case j: HashOuterJoin => j + case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -81,12 +82,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -133,6 +135,34 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.sql("UNCACHE TABLE testData") } + test("broadcasted hash outer join operator selection") { + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + + ctx.sql("UNCACHE TABLE testData") + } + test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 13f3be8ca28d6..108b1122f7bff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -54,6 +54,37 @@ class SparkPlanTest extends SparkFunSuite { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[Row]): Unit = { + checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + checkAnswer(left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -72,11 +103,41 @@ class SparkPlanTest extends SparkFunSuite { planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[A]): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) - SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } + checkAnswer(input, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(left, right, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(input, planFunction, expectedRows) } + } /** @@ -92,27 +153,25 @@ object SparkPlanTest { * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ def checkAnswer( - input: DataFrame, - planFunction: SparkPlan => SparkPlan, + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row]): Option[String] = { - val outputPlan = planFunction(input.queryExecution.sparkPlan) + val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = outputPlan transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { - case (a, i) => - (a.name, BoundReference(i, a.dataType, a.nullable)) - }.toMap - - plan.transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } + val resolvedPlan = TestSQLContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala new file mode 100644 index 0000000000000..5707d2fb300ae --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest { + + val left = Seq( + (1, 2.0), + (2, 1.0), + (3, 3.0) + ).toDF("a", "b") + + val right = Seq( + (2, 3.0), + (3, 2.0), + (4, 1.0) + ).toDF("c", "d") + + val leftKeys: List[Expression] = 'a :: Nil + val rightKeys: List[Expression] = 'c :: Nil + val condition = Some(LessThan('b, 'd)) + + test("shuffled hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } + + test("broadcast hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } +} From 132e7fca129be8f00ba429a51bcef60abb2eed6d Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 6 Jul 2015 15:54:43 -0700 Subject: [PATCH 149/274] [MINOR] [SQL] remove unused code in Exchange Author: Daoyuan Wang Closes #7234 from adrian-wang/exchangeclean and squashes the following commits: b093ec9 [Daoyuan Wang] remove unused code --- .../org/apache/spark/sql/execution/Exchange.scala | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index edc64a03335d6..e054c1d144e34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -117,20 +117,6 @@ case class Exchange( } } - private val keyOrdering = { - if (newOrdering.nonEmpty) { - val key = newPartitioning.keyExpressions - val boundOrdering = newOrdering.map { o => - val ordinal = key.indexOf(o.child) - if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") - o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) - } - new RowOrdering(boundOrdering) - } else { - null // Ordering will not be used - } - } - @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf private def getSerializer( From 9ff203346ca4decf2999e33bfb8c400ec75313e6 Mon Sep 17 00:00:00 2001 From: Wisely Chen Date: Mon, 6 Jul 2015 16:04:01 -0700 Subject: [PATCH 150/274] [SPARK-8656] [WEBUI] Fix the webUI and JSON API number is not synced Spark standalone master web UI show "Alive Workers" total core, total used cores and "Alive workers" total memory, memory used. But the JSON API page "http://MASTERURL:8088/json" shows "ALL workers" core, memory number. This webUI data is not sync with the JSON API. The proper way is to sync the number with webUI and JSON API. Author: Wisely Chen Closes #7038 from thegiive/SPARK-8656 and squashes the following commits: 9e54bf0 [Wisely Chen] Change variable name to camel case 2c8ea89 [Wisely Chen] Change some styling and add local variable 431d2b0 [Wisely Chen] Worker List should contain DEAD node also 8b3b8e8 [Wisely Chen] [SPARK-8656] Fix the webUI and JSON API number is not synced --- .../scala/org/apache/spark/deploy/JsonProtocol.scala | 9 +++++---- .../org/apache/spark/deploy/master/WorkerInfo.scala | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 2954f932b4f41..ccffb36652988 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -76,12 +76,13 @@ private[deploy] object JsonProtocol { } def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 471811037e5e2..f751966605206 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -105,4 +105,6 @@ private[spark] class WorkerInfo( def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } From 1165b17d24cdf1dbebb2faca14308dfe5c2a652c Mon Sep 17 00:00:00 2001 From: Ankur Chauhan Date: Mon, 6 Jul 2015 16:04:57 -0700 Subject: [PATCH 151/274] [SPARK-6707] [CORE] [MESOS] Mesos Scheduler should allow the user to specify constraints based on slave attributes Currently, the mesos scheduler only looks at the 'cpu' and 'mem' resources when trying to determine the usablility of a resource offer from a mesos slave node. It may be preferable for the user to be able to ensure that the spark jobs are only started on a certain set of nodes (based on attributes). For example, If the user sets a property, let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. Author: Ankur Chauhan Closes #5563 from ankurcha/mesos_attribs and squashes the following commits: 902535b [Ankur Chauhan] Fix line length d83801c [Ankur Chauhan] Update code as per code review comments 8b73f2d [Ankur Chauhan] Fix imports c3523e7 [Ankur Chauhan] Added docs 1a24d0b [Ankur Chauhan] Expand scope of attributes matching to include all data types 482fd71 [Ankur Chauhan] Update access modifier to private[this] for offer constraints 5ccc32d [Ankur Chauhan] Fix nit pick whitespace 1bce782 [Ankur Chauhan] Fix nit pick whitespace c0cbc75 [Ankur Chauhan] Use offer id value for debug message 7fee0ea [Ankur Chauhan] Add debug statements fc7eb5b [Ankur Chauhan] Fix import codestyle 00be252 [Ankur Chauhan] Style changes as per code review comments 662535f [Ankur Chauhan] Incorporate code review comments + use SparkFunSuite fdc0937 [Ankur Chauhan] Decline offers that did not meet criteria 67b58a0 [Ankur Chauhan] Add documentation for spark.mesos.constraints 63f53f4 [Ankur Chauhan] Update codestyle - uniform style for config values 02031e4 [Ankur Chauhan] Fix scalastyle warnings in tests c09ed84 [Ankur Chauhan] Fixed the access modifier on offerConstraints val to private[mesos] 0c64df6 [Ankur Chauhan] Rename overhead fractions to memory_*, fix spacing 8cc1e8f [Ankur Chauhan] Make exception message more explicit about the source of the error addedba [Ankur Chauhan] Added test case for malformed constraint string ec9d9a6 [Ankur Chauhan] Add tests for parse constraint string 72fe88a [Ankur Chauhan] Fix up tests + remove redundant method override, combine utility class into new mesos scheduler util trait 92b47fd [Ankur Chauhan] Add attributes based constraints support to MesosScheduler --- .../mesos/CoarseMesosSchedulerBackend.scala | 43 +++-- .../scheduler/cluster/mesos/MemoryUtils.scala | 31 ---- .../cluster/mesos/MesosClusterScheduler.scala | 1 + .../cluster/mesos/MesosSchedulerBackend.scala | 62 ++++--- .../cluster/mesos/MesosSchedulerUtils.scala | 153 +++++++++++++++++- .../cluster/mesos/MemoryUtilsSuite.scala | 46 ------ .../mesos/MesosSchedulerBackendSuite.scala | 6 +- .../mesos/MesosSchedulerUtilsSuite.scala | 140 ++++++++++++++++ docs/running-on-mesos.md | 22 +++ 9 files changed, 376 insertions(+), 128 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 6b8edca5aa485..b68f8c7685eba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,18 +18,18 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{Collections, List => JList} +import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -66,6 +66,10 @@ private[spark] class CoarseMesosSchedulerBackend( val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -170,13 +174,16 @@ private[spark] class CoarseMesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val slaveId = offer.getSlaveId.toString val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (meetsConstraints && + totalCoresAcquired < maxCores && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -193,33 +200,25 @@ private[spark] class CoarseMesosSchedulerBackend( .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) + .addResources(createResource("mem", calculateTotalMemory(sc))) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder) } - d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.launchTasks(List(offer.getId), List(task.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala deleted file mode 100644 index 8df4f3b554c41..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.apache.spark.SparkContext - -private[spark] object MemoryUtils { - // These defaults copied from YARN - val OVERHEAD_FRACTION = 0.10 - val OVERHEAD_MINIMUM = 384 - - def calculateTotalMemory(sc: SparkContext): Int = { - sc.conf.getInt("spark.mesos.executor.memoryOverhead", - math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1067a7f1caf4c..d3a20f822176e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 49de85ef48ada..d72e2af456e15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,14 +23,14 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a @@ -59,6 +59,10 @@ private[spark] class MesosSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { @@ -71,8 +75,8 @@ private[spark] class MesosSchedulerBackend( val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -115,14 +119,14 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) .setScalar( Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) + .setValue(calculateTotalMemory(sc)).build()) .build() val executorInfo = MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -191,13 +195,31 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt @@ -223,15 +245,15 @@ private[spark] class MesosSchedulerBackend( val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) acceptedOffers .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) - } + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } + } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? @@ -251,8 +273,6 @@ private[spark] class MesosSchedulerBackend( d.declineOffer(o.getId) } - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d11228f3d016a..d8a8c848bb4d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,14 +17,17 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util.List +import java.util.{List => JList} import java.util.concurrent.CountDownLatch import scala.collection.JavaConversions._ +import scala.util.control.NonFatal -import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} -import org.apache.mesos.{MesosSchedulerDriver, Scheduler} -import org.apache.spark.Logging +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.GeneratedMessage +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.util.Utils /** @@ -86,10 +89,150 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Get the amount of resources for the specified type from the resource list */ - protected def getResource(res: List[Resource], name: String): Double = { + protected def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } 0.0 } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * The attribute values are the mesos attribute types and they are + * @param offerAttributes + * @return + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + + /** + * Match the requirements (if any) to the offer attributes. + * if attribute requirements are not specified - return true + * else if attribute is defined and no values are given, simple attribute presence is performed + * else if attribute name and value is specified, subset match is performed on slave attributes + */ + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if provided values is less than equal to the offered values + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that is between the ranges specified + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values is a subset of offered set + requiredValues.subsetOf(offeredValue.getItemList.toSet) + case Some(textValue: Value.Text) => + // check if the specified value is equal, if multiple values are specified + // we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } + } + } + + /** + * Parses the attributes constraints provided to spark and build a matching data struct: + * Map[, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone": -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match resources offers. + */ + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { + case (k, v) => + if (v == null || v.isEmpty) { + (k, Set[String]()) + } else { + (k, v.split(',').toSet) + } + } + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } + } + } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala deleted file mode 100644 index e72285d03d3ee..0000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} - -class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { - test("MesosMemoryUtils should always override memoryOverhead when it's set") { - val sparkConf = new SparkConf - - val sc = mock[SparkContext] - when(sc.conf).thenReturn(sparkConf) - - // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 - when(sc.executorMemory).thenReturn(512) - assert(MemoryUtils.calculateTotalMemory(sc) === 896) - - // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 - when(sc.executorMemory).thenReturn(4096) - assert(MemoryUtils.calculateTotalMemory(sc) === 4505) - - // set memoryOverhead - sparkConf.set("spark.mesos.executor.memoryOverhead", "100") - assert(MemoryUtils.calculateTotalMemory(sc) === 4196) - sparkConf.set("spark.mesos.executor.memoryOverhead", "400") - assert(MemoryUtils.calculateTotalMemory(sc) === 4496) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 68df46a41ddc8..d01837fe78957 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -149,7 +149,9 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(sc.conf).thenReturn(new SparkConf) when(sc.listenerBus).thenReturn(listenerBus) - val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] @@ -157,8 +159,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi mesosOffers.add(createOffer(2, minMem - 1, minCpu)) mesosOffers.add(createOffer(3, minMem, minCpu)) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala new file mode 100644 index 0000000000000..b354914b6ffd0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.Value +import org.mockito.Mockito._ +import org.scalatest._ +import org.scalatest.mock.MockitoSugar +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { + + // scalastyle:off structural.type + // this is the documented way of generating fixtures in scalatest + def fixture: Object {val sc: SparkContext; val sparkConf: SparkConf} = new { + val sparkConf = new SparkConf + val sc = mock[SparkContext] + when(sc.conf).thenReturn(sparkConf) + } + val utils = new MesosSchedulerUtils { } + // scalastyle:on structural.type + + test("use at-least minimum overhead") { + val f = fixture + when(f.sc.executorMemory).thenReturn(512) + utils.calculateTotalMemory(f.sc) shouldBe 896 + } + + test("use overhead if it is greater than minimum value") { + val f = fixture + when(f.sc.executorMemory).thenReturn(4096) + utils.calculateTotalMemory(f.sc) shouldBe 4505 + } + + test("use spark.mesos.executor.memoryOverhead (if set)") { + val f = fixture + when(f.sc.executorMemory).thenReturn(1024) + f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") + utils.calculateTotalMemory(f.sc) shouldBe 1536 + } + + test("parse a non-empty constraint string correctly") { + val expectedMap = Map( + "tachyon" -> Set("true"), + "zone" -> Set("us-east-1a", "us-east-1b") + ) + utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap) + } + + test("parse an empty constraint string correctly") { + utils.parseConstraintString("") shouldBe Map() + } + + test("throw an exception when the input is malformed") { + an[IllegalArgumentException] should be thrownBy + utils.parseConstraintString("tachyon;zone:us-east") + } + + test("empty values for attributes' constraints matches all values") { + val constraintsStr = "tachyon:" + val parsedConstraints = utils.parseConstraintString(constraintsStr) + + parsedConstraints shouldBe Map("tachyon" -> Set()) + + val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build() + val noTachyonOffer = Map("zone" -> zoneSet) + val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build()) + + utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false + utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true + } + + test("subset match is performed for set attributes") { + val supersetConstraint = Map( + "tachyon" -> Value.Text.newBuilder().setValue("true").build(), + "zone" -> Value.Set.newBuilder() + .addItem("us-east-1a") + .addItem("us-east-1b") + .addItem("us-east-1c") + .build()) + + val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c" + val parsedConstraints = utils.parseConstraintString(zoneConstraintStr) + + utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true + } + + test("less than equal match is performed on scalar attributes") { + val offerAttribs = Map("gpus" -> Value.Scalar.newBuilder().setValue(3).build()) + + val ltConstraint = utils.parseConstraintString("gpus:2") + val eqConstraint = utils.parseConstraintString("gpus:3") + val gtConstraint = utils.parseConstraintString("gpus:4") + + utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + } + + test("contains match is performed for range attributes") { + val offerAttribs = Map("ports" -> Value.Range.newBuilder().setBegin(7000).setEnd(8000).build()) + val ltConstraint = utils.parseConstraintString("ports:6000") + val eqConstraint = utils.parseConstraintString("ports:7500") + val gtConstraint = utils.parseConstraintString("ports:8002") + val multiConstraint = utils.parseConstraintString("ports:5000,7500,8300") + + utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(multiConstraint, offerAttribs) shouldBe true + } + + test("equality match is performed for text attributes") { + val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + + val trueConstraint = utils.parseConstraintString("tachyon:true") + val falseConstraint = utils.parseConstraintString("tachyon:false") + + utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false + } + +} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5f1d6daeb27f0..1f915d8ea1d73 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,14 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. + # Mesos Docker Support Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` @@ -298,6 +306,20 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.constraints + Attribute based constraints to be matched against when accepting resource offers. + + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
    +
  • Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
  • +
  • Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
  • +
  • Set constraints are matched with "subset of" semantics i.e. value in the constraint must be a subset of the resource offer's value.
  • +
  • Text constraints are metched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
  • +
  • In case there is no value present as a part of the constraint any offer with the corresponding attribute will be accepted (without value check).
  • +
+ + # Troubleshooting and Debugging From 96c5eeec3970e8b1ebc6ddf5c97a7acc47f539dc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 6 Jul 2015 16:11:22 -0700 Subject: [PATCH 152/274] Revert "[SPARK-7212] [MLLIB] Add sequence learning flag" This reverts commit 25f574eb9a3cb9b93b7d9194a8ec16e00ce2c036. After speaking to some users and developers, we realized that FP-growth doesn't meet the requirement for frequent sequence mining. PrefixSpan (SPARK-6487) would be the correct algorithm for it. feynmanliang Author: Xiangrui Meng Closes #7240 from mengxr/SPARK-7212.revert and squashes the following commits: 2b3d66b [Xiangrui Meng] Revert "[SPARK-7212] [MLLIB] Add sequence learning flag" --- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 38 +++----------- .../spark/mllib/fpm/FPGrowthSuite.scala | 52 +------------------ python/pyspark/mllib/fpm.py | 4 +- 3 files changed, 12 insertions(+), 82 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index abac08022ea47..efa8459d3cdba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -36,7 +36,7 @@ import org.apache.spark.storage.StorageLevel * :: Experimental :: * * Model trained by [[FPGrowth]], which holds frequent itemsets. - * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]] + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] * @tparam Item item type */ @Experimental @@ -62,14 +62,13 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex @Experimental class FPGrowth private ( private var minSupport: Double, - private var numPartitions: Int, - private var ordered: Boolean) extends Logging with Serializable { + private var numPartitions: Int) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same - * as the input data, ordered: `false`}. + * as the input data}. */ - def this() = this(0.3, -1, false) + def this() = this(0.3, -1) /** * Sets the minimal support level (default: `0.3`). @@ -87,15 +86,6 @@ class FPGrowth private ( this } - /** - * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine - * itemsets). - */ - def setOrdered(ordered: Boolean): this.type = { - this.ordered = ordered - this - } - /** * Computes an FP-Growth model that contains frequent itemsets. * @param data input data set, each element contains a transaction @@ -165,7 +155,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered) + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) } } @@ -181,12 +171,9 @@ class FPGrowth private ( itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] - // Filter the basket by frequent items pattern + // Filter the basket by frequent items pattern and sort their ranks. val filtered = transaction.flatMap(itemToRank.get) - if (!this.ordered) { - ju.Arrays.sort(filtered) - } - // Generate conditional transactions + ju.Arrays.sort(filtered) val n = filtered.length var i = n - 1 while (i >= 0) { @@ -211,18 +198,9 @@ object FPGrowth { * Frequent itemset. * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. * @param freq frequency - * @param ordered indicates if items represents an itemset (false) or sequence (true) * @tparam Item item type */ - class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean) - extends Serializable { - - /** - * Auxillary constructor, assumes unordered by default. - */ - def this(items: Array[Item], freq: Long) { - this(items, freq, false) - } + class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { /** * Returns items in a Java List. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 1a8a1e79f2810..66ae3543ecc4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { - test("FP-Growth frequent itemsets using String type") { + test("FP-Growth using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -38,14 +38,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) - .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .setOrdered(false) .run(rdd) val freqItemsets3 = model3.freqItemsets.collect().map { itemset => (itemset.items.toSet, itemset.freq) @@ -63,59 +61,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) - .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) - .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 625) } - test("FP-Growth frequent sequences using String type"){ - val transactions = Seq( - "r z h k p", - "z y x w v u t s", - "s x o n r", - "x z y m t s q e", - "z", - "x z y r q t p") - .map(_.split(" ")) - val rdd = sc.parallelize(transactions, 2).cache() - - val fpg = new FPGrowth() - - val model1 = fpg - .setMinSupport(0.5) - .setNumPartitions(2) - .setOrdered(true) - .run(rdd) - - /* - Use the following R code to verify association rules using arulesSequences package. - - data = read_baskets("path", info = c("sequenceID","eventID","SIZE")) - freqItemSeq = cspade(data, parameter = list(support = 0.5)) - resSeq = as(freqItemSeq, "data.frame") - resSeq$support = resSeq$support * length(transactions) - names(resSeq)[names(resSeq) == "support"] = "freq" - resSeq - */ - val expected = Set( - (Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L), - (Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L), - (Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L) - ) - val freqItemseqs1 = model1.freqItemsets.collect().map { itemset => - (itemset.items.toSeq, itemset.freq) - }.toSet - assert(freqItemseqs1 == expected) - } - - test("FP-Growth frequent itemsets using Int type") { + test("FP-Growth using Int type") { val transactions = Seq( "1 2 3", "1 2 3 4", @@ -132,14 +88,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) - .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .setOrdered(false) .run(rdd) assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, "frequent itemsets should use primitive arrays") @@ -155,14 +109,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) - .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 15) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) - .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 65) } diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index b7f00d60069e6..bdc4a132b1b18 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper): >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) - >>> sorted(model.freqItemsets().collect(), key=lambda x: x.items) - [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ... + >>> sorted(model.freqItemsets().collect()) + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... """ def freqItemsets(self): From 0effe180f4c2cf37af1012b33b43912bdecaf756 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 6 Jul 2015 16:15:12 -0700 Subject: [PATCH 153/274] [SPARK-8765] [MLLIB] Fix PySpark PowerIterationClustering test issue PySpark PowerIterationClustering test failure due to bad demo data. If the data is small, PowerIterationClustering will behavior indeterministic. Author: Yanbo Liang Closes #7177 from yanboliang/spark-8765 and squashes the following commits: 392ae54 [Yanbo Liang] fix model.assignments output 5ec3f1e [Yanbo Liang] fix PySpark PowerIterationClustering test issue --- python/pyspark/mllib/clustering.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a3eab635282f6..ed4d78a2c6788 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -282,18 +282,30 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): Model produced by [[PowerIterationClustering]]. - >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0), - ... (0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)] + >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0), + ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), + ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0), + ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)] >>> rdd = sc.parallelize(data, 2) >>> model = PowerIterationClustering.train(rdd, 2, 100) >>> model.k 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = PowerIterationClusteringModel.load(sc, path) >>> sameModel.k 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True >>> from shutil import rmtree >>> try: ... rmtree(path) From 7b467cc9348fa910e445ad08914a72f8ed4fc249 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 6 Jul 2015 16:26:31 -0700 Subject: [PATCH 154/274] [SPARK-8588] [SQL] Regression test This PR adds regression test for https://issues.apache.org/jira/browse/SPARK-8588 (fixed by https://github.com/apache/spark/commit/457d07eaa023b44b75344110508f629925eb6247). Author: Yin Huai This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #7103 from yhuai/SPARK-8588-test and squashes the following commits: eb5f418 [Yin Huai] Add a query test. c61a173 [Yin Huai] Regression test for SPARK-8588. --- .../analysis/HiveTypeCoercionSuite.scala | 21 +++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index b56426617789e..93db33d44eb25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -271,4 +271,25 @@ class HiveTypeCoercionSuite extends PlanTest { Literal(true) ) } + + /** + * There are rules that need to not fire before child expressions get resolved. + * We use this test to make sure those rules do not fire early. + */ + test("make sure rules do not fire early") { + // InConversion + val inConversion = HiveTypeCoercion.InConversion + ruleTest(inConversion, + In(UnresolvedAttribute("a"), Seq(Literal(1))), + In(UnresolvedAttribute("a"), Seq(Literal(1))) + ) + ruleTest(inConversion, + In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))), + In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))) + ) + ruleTest(inConversion, + In(Literal("a"), Seq(Literal(1), Literal("b"))), + In(Literal("a"), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6d645393a6da1..bf9f2ecd51793 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -990,5 +990,21 @@ class SQLQuerySuite extends QueryTest { Timestamp.valueOf("1969-12-31 16:00:00"), String.valueOf("1969-12-31 16:00:00"), Timestamp.valueOf("1970-01-01 00:00:00"))) + + } + + test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { + val df = + TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + df.toDF("id", "date").registerTempTable("test_SPARK8588") + checkAnswer( + TestHive.sql( + """ + |select id, concat(year(date)) + |from test_SPARK8588 where concat(year(date), ' year') in ('2015 year', '2014 year') + """.stripMargin), + Row(1, "2014") :: Row(2, "2015") :: Nil + ) + TestHive.dropTempTable("test_SPARK8588") } } From 09a06418debc25da0191d98798f7c5016d39be91 Mon Sep 17 00:00:00 2001 From: animesh Date: Mon, 6 Jul 2015 16:39:49 -0700 Subject: [PATCH 155/274] [SPARK-8072] [SQL] Better AnalysisException for writing DataFrame with identically named columns Adding a function checkConstraints which will check for the constraints to be applied on the dataframe / dataframe schema. Function called before storing the dataframe to an external storage. Function added in the corresponding datasource API. cc rxin marmbrus Author: animesh This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #7013 from animeshbaranawal/8072 and squashes the following commits: f70dd0e [animesh] Change IO exception to Analysis Exception fd45e1b [animesh] 8072: Fix Style Issues a8a964f [animesh] 8072: Improving on previous commits 3cc4d2c [animesh] Fix Style Issues 1a89115 [animesh] Fix Style Issues 98b4399 [animesh] 8072 : Moved the exception handling to ResolvedDataSource specific to parquet format 7c3d928 [animesh] 8072: Adding check to DataFrameWriter.scala --- .../apache/spark/sql/json/JSONRelation.scala | 31 +++++++++++++++++++ .../apache/spark/sql/parquet/newParquet.scala | 19 +++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 24 ++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 69bf13e1e5a6a..2361d3bf52d2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -22,6 +22,7 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -37,6 +38,17 @@ private[sql] class DefaultSource parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) } + /** Constraints to be imposed on dataframe to be stored. */ + private def checkConstraints(data: DataFrame): Unit = { + if (data.schema.fieldNames.length != data.schema.fieldNames.distinct.length) { + val duplicateColumns = data.schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } + /** Returns a new base relation with the parameters. */ override def createRelation( sqlContext: SQLContext, @@ -63,6 +75,10 @@ private[sql] class DefaultSource mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = { + // check if dataframe satisfies the constraints + // before moving forward + checkConstraints(data) + val path = checkPath(parameters) val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) @@ -130,6 +146,17 @@ private[sql] class JSONRelation( samplingRatio, userSpecifiedSchema)(sqlContext) + /** Constraints to be imposed on dataframe to be stored. */ + private def checkConstraints(data: DataFrame): Unit = { + if (data.schema.fieldNames.length != data.schema.fieldNames.distinct.length) { + val duplicateColumns = data.schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } + private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI override val needConversion: Boolean = false @@ -178,6 +205,10 @@ private[sql] class JSONRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { + // check if dataframe satisfies constraints + // before moving forward + checkConstraints(data) + val filesystemPath = path match { case Some(p) => new Path(p) case None => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 5ac3e9a44e6fe..6bc69c6ad0847 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -164,7 +164,24 @@ private[sql] class ParquetRelation2( } } - override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) + /** Constraints on schema of dataframe to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to parquet format") + } + } + + override def dataSchema: StructType = { + val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) + // check if schema satisfies the constraints + // before moving forward + checkConstraints(schema) + schema + } override private[sql] def refresh(): Unit = { super.refresh() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index afb1cf5f8d1cb..f592a9934d0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -737,4 +737,28 @@ class DataFrameSuite extends QueryTest { df.col("") df.col("t.``") } + + test("SPARK-8072: Better Exception for Duplicate Columns") { + // only one duplicate column present + val e = intercept[org.apache.spark.sql.AnalysisException] { + val df1 = Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") + .write.format("parquet").save("temp") + } + assert(e.getMessage.contains("Duplicate column(s)")) + assert(e.getMessage.contains("parquet")) + assert(e.getMessage.contains("column1")) + assert(!e.getMessage.contains("column2")) + + // multiple duplicate columns present + val f = intercept[org.apache.spark.sql.AnalysisException] { + val df2 = Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) + .toDF("column1", "column2", "column3", "column1", "column3") + .write.format("json").save("temp") + } + assert(f.getMessage.contains("Duplicate column(s)")) + assert(f.getMessage.contains("JSON")) + assert(f.getMessage.contains("column1")) + assert(f.getMessage.contains("column3")) + assert(!f.getMessage.contains("column2")) + } } From d4d6d31db5cc5c69ac369f754b7489f444c9ba2f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 6 Jul 2015 17:16:44 -0700 Subject: [PATCH 156/274] [SPARK-8463][SQL] Use DriverRegistry to load jdbc driver at writing path JIRA: https://issues.apache.org/jira/browse/SPARK-8463 Currently, at the reading path, `DriverRegistry` is used to load needed jdbc driver at executors. However, at the writing path, we also need `DriverRegistry` to load jdbc driver. Author: Liang-Chi Hsieh Closes #6900 from viirya/jdbc_write_driver and squashes the following commits: 16cd04b [Liang-Chi Hsieh] Use DriverRegistry to load jdbc driver at writing path. --- .../main/scala/org/apache/spark/sql/jdbc/jdbc.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index dd8aaf6474895..f7ea852fe7f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -58,13 +58,12 @@ package object jdbc { * are used. */ def savePartition( - url: String, + getConnection: () => Connection, table: String, iterator: Iterator[Row], rddSchema: StructType, - nullTypes: Array[Int], - properties: Properties): Iterator[Byte] = { - val conn = DriverManager.getConnection(url, properties) + nullTypes: Array[Int]): Iterator[Byte] = { + val conn = getConnection() var committed = false try { conn.setAutoCommit(false) // Everything in the same db transaction. @@ -185,8 +184,10 @@ package object jdbc { } val rddSchema = df.schema + val driver: String = DriverRegistry.getDriverClassName(url) + val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes, properties) + JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) } } From 9eae5fa642317dd11fc783d832d4cbb7e62db471 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 6 Jul 2015 19:22:30 -0700 Subject: [PATCH 157/274] [SPARK-8819] Fix build for maven 3.3.x This is a workaround for MSHADE-148, which leads to an infinite loop when building Spark with maven 3.3.x. This was originally caused by #6441, which added a bunch of test dependencies on the spark-core test module. Recently, it was revealed by #7193. This patch adds a `-Prelease` profile. If present, it will set `createDependencyReducedPom` to true. The consequences are: - If you are releasing Spark with this profile, you are fine as long as you use maven 3.2.x or before. - If you are releasing Spark without this profile, you will run into SPARK-8781. - If you are not releasing Spark but you are using this profile, you may run into SPARK-8819. - If you are not releasing Spark and you did not include this profile, you are fine. This is all documented in `pom.xml` and tested locally with both versions of maven. Author: Andrew Or Closes #7219 from andrewor14/fix-maven-build and squashes the following commits: 1d37e87 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-maven-build 3574ae4 [Andrew Or] Review comments f39199c [Andrew Or] Create a -Prelease profile that flags `createDependencyReducedPom` --- dev/create-release/create-release.sh | 4 ++-- pom.xml | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 54274a83f6d66..cfe2cd4752b3f 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,13 +118,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO - build/mvn -DskipTests -Pyarn -Phive \ + build/mvn -DskipTests -Pyarn -Phive -Prelease-profile\ -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh - build/mvn -DskipTests -Pyarn -Phive \ + build/mvn -DskipTests -Pyarn -Phive -Prelease-profile\ -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install diff --git a/pom.xml b/pom.xml index ffa96128a3d61..fbcc9152765cf 100644 --- a/pom.xml +++ b/pom.xml @@ -161,6 +161,8 @@ 2.4.4 1.1.1.7 1.1.2 + + false ${java.home} @@ -1440,6 +1442,8 @@ 2.3 false + + ${create.dependency.reduced.pom} @@ -1826,6 +1830,26 @@ + + + release-profile + + + true + + + - release-profile + release false @@ -179,6 +180,8 @@ compile compile compile + test + test + + twttr-repo + Twttr Repository + http://maven.twttr.com + + true + + + false + + spark-1.4-staging @@ -1101,6 +1116,24 @@ ${parquet.version} ${parquet.deps.scope}
+ + org.apache.parquet + parquet-avro + ${parquet.version} + ${parquet.test.deps.scope} + + + org.apache.parquet + parquet-thrift + ${parquet.version} + ${parquet.test.deps.scope} + + + org.apache.thrift + libthrift + ${thrift.version} + ${thrift.test.deps.scope} + org.apache.flume flume-ng-core diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 41e19fd9cc11e..7346d804632bc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,21 +62,8 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticCostFun.this"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), - // NanoTime and CatalystTimestampConverter is only used inside catalyst, - // not needed anymore - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.timestamp.NanoTime"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.timestamp.NanoTime$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.CatalystTimestampConverter"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.CatalystTimestampConverter$"), - // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTypeInfo"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTypeInfo$") + // Parquet support is considered private. + excludePackage("org.apache.spark.sql.parquet") ) ++ Seq( // SPARK-8479 Add numNonzeros and numActives to Matrix. ProblemFilters.exclude[MissingMethodProblem]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7d00047d08d74..a4c2da8e05f5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.types +import scala.util.Try import scala.util.parsing.combinator.RegexParsers -import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ +import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi @@ -82,6 +83,9 @@ abstract class DataType extends AbstractDataType { object DataType { + private[sql] def fromString(raw: String): DataType = { + Try(DataType.fromJson(raw)).getOrElse(DataType.fromCaseClassString(raw)) + } def fromJson(json: String): DataType = parseDataType(parse(json)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3b17566d54d9b..e2d3f53f7d978 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -311,6 +311,11 @@ object StructType extends AbstractDataType { private[sql] override def simpleString: String = "struct" + private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match { + case t: StructType => t + case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + } + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 8fc16928adbd9..f90099f22d4bd 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -101,9 +101,45 @@ 9.3-1102-jdbc41 test + + org.apache.parquet + parquet-avro + test + + + org.apache.parquet + parquet-thrift + test + + + org.apache.thrift + libthrift + test + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + src/test/gen-java + + + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala new file mode 100644 index 0000000000000..0c3d8fdab6bd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.nio.ByteOrder + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ParentContainerUpdater]] is used by a Parquet converter to set converted values to some + * corresponding parent container. For example, a converter for a `StructType` field may set + * converted values to a [[MutableRow]]; or a converter for array elements may append converted + * values to an [[ArrayBuffer]]. + */ +private[parquet] trait ParentContainerUpdater { + def set(value: Any): Unit = () + def setBoolean(value: Boolean): Unit = set(value) + def setByte(value: Byte): Unit = set(value) + def setShort(value: Short): Unit = set(value) + def setInt(value: Int): Unit = set(value) + def setLong(value: Long): Unit = set(value) + def setFloat(value: Float): Unit = set(value) + def setDouble(value: Double): Unit = set(value) +} + +/** A no-op updater used for root converter (who doesn't have a parent). */ +private[parquet] object NoopUpdater extends ParentContainerUpdater + +/** + * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[Row]]s. Since + * any Parquet record is also a struct, this converter can also be used as root converter. + * + * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have + * any "parent" container. + * + * @param parquetType Parquet schema of Parquet records + * @param catalystType Spark SQL schema that corresponds to the Parquet record type + * @param updater An updater which propagates converted field values to the parent container + */ +private[parquet] class CatalystRowConverter( + parquetType: GroupType, + catalystType: StructType, + updater: ParentContainerUpdater) + extends GroupConverter { + + /** + * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates + * converted filed values to the `ordinal`-th cell in `currentRow`. + */ + private final class RowUpdater(row: MutableRow, ordinal: Int) extends ParentContainerUpdater { + override def set(value: Any): Unit = row(ordinal) = value + override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(value: Short): Unit = row.setShort(ordinal, value) + override def setInt(value: Int): Unit = row.setInt(ordinal, value) + override def setLong(value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) + } + + /** + * Represents the converted row object once an entire Parquet record is converted. + * + * @todo Uses [[UnsafeRow]] for better performance. + */ + val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + + // Converters for each field. + private val fieldConverters: Array[Converter] = { + parquetType.getFields.zip(catalystType).zipWithIndex.map { + case ((parquetFieldType, catalystField), ordinal) => + // Converted field value should be set to the `ordinal`-th cell of `currentRow` + newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal)) + }.toArray + } + + override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) + + override def end(): Unit = updater.set(currentRow) + + override def start(): Unit = { + var i = 0 + while (i < currentRow.length) { + currentRow.setNullAt(i) + i += 1 + } + } + + /** + * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type + * `catalystType`. Converted values are handled by `updater`. + */ + private def newConverter( + parquetType: Type, + catalystType: DataType, + updater: ParentContainerUpdater): Converter = { + + catalystType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => + new CatalystPrimitiveConverter(updater) + + case ByteType => + new PrimitiveConverter { + override def addInt(value: Int): Unit = + updater.setByte(value.asInstanceOf[ByteType#InternalType]) + } + + case ShortType => + new PrimitiveConverter { + override def addInt(value: Int): Unit = + updater.setShort(value.asInstanceOf[ShortType#InternalType]) + } + + case t: DecimalType => + new CatalystDecimalConverter(t, updater) + + case StringType => + new CatalystStringConverter(updater) + + case TimestampType => + // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + new PrimitiveConverter { + // Converts nanosecond timestamps stored as INT96 + override def addBinary(value: Binary): Unit = { + assert( + value.length() == 12, + "Timestamps (with nanoseconds) are expected to be stored in 12-byte long binaries, " + + s"but got a ${value.length()}-byte binary.") + + val buf = value.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buf.getLong + val julianDay = buf.getInt + updater.setLong(DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)) + } + } + + case DateType => + new PrimitiveConverter { + override def addInt(value: Int): Unit = { + // DateType is not specialized in `SpecificMutableRow`, have to box it here. + updater.set(value.asInstanceOf[DateType#InternalType]) + } + } + + case t: ArrayType => + new CatalystArrayConverter(parquetType.asGroupType(), t, updater) + + case t: MapType => + new CatalystMapConverter(parquetType.asGroupType(), t, updater) + + case t: StructType => + new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { + override def set(value: Any): Unit = updater.set(value.asInstanceOf[Row].copy()) + }) + + case t: UserDefinedType[_] => + val catalystTypeForUDT = t.sqlType + val nullable = parquetType.isRepetition(Repetition.OPTIONAL) + val field = StructField("udt", catalystTypeForUDT, nullable) + val parquetTypeForUDT = new CatalystSchemaConverter().convertField(field) + newConverter(parquetTypeForUDT, catalystTypeForUDT, updater) + + case _ => + throw new RuntimeException( + s"Unable to create Parquet converter for data type ${catalystType.json}") + } + } + + /** + * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types + * are handled by this converter. Parquet primitive types are only a subset of those of Spark + * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. + */ + private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater) + extends PrimitiveConverter { + + override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) + override def addInt(value: Int): Unit = updater.setInt(value) + override def addLong(value: Long): Unit = updater.setLong(value) + override def addFloat(value: Float): Unit = updater.setFloat(value) + override def addDouble(value: Double): Unit = updater.setDouble(value) + override def addBinary(value: Binary): Unit = updater.set(value.getBytes) + } + + /** + * Parquet converter for strings. A dictionary is used to minimize string decoding cost. + */ + private final class CatalystStringConverter(updater: ParentContainerUpdater) + extends PrimitiveConverter { + + private var expandedDictionary: Array[UTF8String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { i => + UTF8String.fromBytes(dictionary.decodeToBinary(i).getBytes) + } + } + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + override def addBinary(value: Binary): Unit = { + updater.set(UTF8String.fromBytes(value.getBytes)) + } + } + + /** + * Parquet converter for fixed-precision decimals. + */ + private final class CatalystDecimalConverter( + decimalType: DecimalType, + updater: ParentContainerUpdater) + extends PrimitiveConverter { + + // Converts decimals stored as INT32 + override def addInt(value: Int): Unit = { + addLong(value: Long) + } + + // Converts decimals stored as INT64 + override def addLong(value: Long): Unit = { + updater.set(Decimal(value, decimalType.precision, decimalType.scale)) + } + + // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY + override def addBinary(value: Binary): Unit = { + updater.set(toDecimal(value)) + } + + private def toDecimal(value: Binary): Decimal = { + val precision = decimalType.precision + val scale = decimalType.scale + val bytes = value.getBytes + + var unscaled = 0L + var i = 0 + + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * bytes.length + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + Decimal(unscaled, precision, scale) + } + } + + /** + * Parquet converter for arrays. Spark SQL arrays are represented as Parquet lists. Standard + * Parquet lists are represented as a 3-level group annotated by `LIST`: + * {{{ + * group (LIST) { <-- parquetSchema points here + * repeated group list { + * element; + * } + * } + * }}} + * The `parquetSchema` constructor argument points to the outermost group. + * + * However, before this representation is standardized, some Parquet libraries/tools also use some + * non-standard formats to represent list-like structures. Backwards-compatibility rules for + * handling these cases are described in Parquet format spec. + * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + */ + private final class CatalystArrayConverter( + parquetSchema: GroupType, + catalystSchema: ArrayType, + updater: ParentContainerUpdater) + extends GroupConverter { + + private var currentArray: ArrayBuffer[Any] = _ + + private val elementConverter: Converter = { + val repeatedType = parquetSchema.getType(0) + val elementType = catalystSchema.elementType + + if (isElementType(repeatedType, elementType)) { + newConverter(repeatedType, elementType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentArray += value + }) + } else { + new ElementConverter(repeatedType.asGroupType().getType(0), elementType) + } + } + + override def getConverter(fieldIndex: Int): Converter = elementConverter + + override def end(): Unit = updater.set(currentArray) + + // NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the + // next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored + // in row cells. + override def start(): Unit = currentArray = ArrayBuffer.empty[Any] + + // scalastyle:off + /** + * Returns whether the given type is the element type of a list or is a syntactic group with + * one field that is the element type. This is determined by checking whether the type can be + * a syntactic group and by checking whether a potential syntactic group matches the expected + * schema. + * {{{ + * group (LIST) { + * repeated group list { <-- repeatedType points here + * element; + * } + * } + * }}} + * In short, here we handle Parquet list backwards-compatibility rules on the read path. This + * method is based on `AvroIndexedRecordConverter.isElementType`. + * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + */ + // scalastyle:on + private def isElementType(parquetRepeatedType: Type, catalystElementType: DataType): Boolean = { + (parquetRepeatedType, catalystElementType) match { + case (t: PrimitiveType, _) => true + case (t: GroupType, _) if t.getFieldCount > 1 => true + case (t: GroupType, StructType(Array(f))) if f.name == t.getFieldName(0) => true + case _ => false + } + } + + /** Array element converter */ + private final class ElementConverter(parquetType: Type, catalystType: DataType) + extends GroupConverter { + + private var currentElement: Any = _ + + private val converter = newConverter(parquetType, catalystType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentElement = value + }) + + override def getConverter(fieldIndex: Int): Converter = converter + + override def end(): Unit = currentArray += currentElement + + override def start(): Unit = currentElement = null + } + } + + /** Parquet converter for maps */ + private final class CatalystMapConverter( + parquetType: GroupType, + catalystType: MapType, + updater: ParentContainerUpdater) + extends GroupConverter { + + private var currentMap: mutable.Map[Any, Any] = _ + + private val keyValueConverter = { + val repeatedType = parquetType.getType(0).asGroupType() + new KeyValueConverter( + repeatedType.getType(0), + repeatedType.getType(1), + catalystType.keyType, + catalystType.valueType) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override def end(): Unit = updater.set(currentMap) + + // NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next + // value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row + // cells. + override def start(): Unit = currentMap = mutable.Map.empty[Any, Any] + + /** Parquet converter for key-value pairs within the map. */ + private final class KeyValueConverter( + parquetKeyType: Type, + parquetValueType: Type, + catalystKeyType: DataType, + catalystValueType: DataType) + extends GroupConverter { + + private var currentKey: Any = _ + + private var currentValue: Any = _ + + private val converters = Array( + // Converter for keys + newConverter(parquetKeyType, catalystKeyType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentKey = value + }), + + // Converter for values + newConverter(parquetValueType, catalystValueType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentValue = value + })) + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + override def end(): Unit = currentMap(currentKey) = currentValue + + override def start(): Unit = { + currentKey = null + currentValue = null + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 4ab274ec17a02..de3a72d8146c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -358,9 +358,24 @@ private[parquet] class CatalystSchemaConverter( case DateType => Types.primitive(INT32, repetition).as(DATE).named(field.name) - // NOTE: !! This timestamp type is not specified in Parquet format spec !! - // However, Impala and older versions of Spark SQL use INT96 to store timestamps with - // nanosecond precision (not TIME_MILLIS or TIMESTAMP_MILLIS described in the spec). + // NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec. + // + // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond + // timestamp in Impala for some historical reasons, it's not recommended to be used for any + // other types and will probably be deprecated in future Parquet format spec. That's the + // reason why Parquet format spec only defines `TIMESTAMP_MILLIS` and `TIMESTAMP_MICROS` which + // are both logical types annotating `INT64`. + // + // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting + // from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store + // a timestamp into a `Long`. This design decision is subject to change though, for example, + // we may resort to microsecond precision in the future. + // + // For Parquet, we plan to write all `TimestampType` value as `TIMESTAMP_MICROS`, but it's + // currently not implemented yet because parquet-mr 1.7.0 (the version we're currently using) + // hasn't implemented `TIMESTAMP_MICROS` yet. + // + // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. case TimestampType => Types.primitive(INT96, repetition).named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 86a77bf965daa..be0a2029d233b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,61 +17,15 @@ package org.apache.spark.sql.parquet -import java.nio.ByteOrder - -import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap} - -import org.apache.parquet.Preconditions -import org.apache.parquet.column.Dictionary -import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.MessageType - import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.parquet.CatalystConverter.FieldType -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * Collection of converters of Parquet types (group and primitive types) that - * model arrays and maps. The conversions are partly based on the AvroParquet - * converters that are part of Parquet in order to be able to process these - * types. - * - * There are several types of converters: - *
    - *
  • [[org.apache.spark.sql.parquet.CatalystPrimitiveConverter]] for primitive - * (numeric, boolean and String) types
  • - *
  • [[org.apache.spark.sql.parquet.CatalystNativeArrayConverter]] for arrays - * of native JVM element types; note: currently null values are not supported!
  • - *
  • [[org.apache.spark.sql.parquet.CatalystArrayConverter]] for arrays of - * arbitrary element types (including nested element types); note: currently - * null values are not supported!
  • - *
  • [[org.apache.spark.sql.parquet.CatalystStructConverter]] for structs
  • - *
  • [[org.apache.spark.sql.parquet.CatalystMapConverter]] for maps; note: - * currently null values are not supported!
  • - *
  • [[org.apache.spark.sql.parquet.CatalystPrimitiveRowConverter]] for rows - * of only primitive element types
  • - *
  • [[org.apache.spark.sql.parquet.CatalystGroupConverter]] for other nested - * records, including the top-level row record
  • - *
- */ private[sql] object CatalystConverter { - // The type internally used for fields - type FieldType = StructField - // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). // Note that "array" for the array elements is chosen by ParquetAvro. // Using a different value will result in Parquet silently dropping columns. val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" val ARRAY_ELEMENTS_SCHEMA_NAME = "array" - // SPARK-4520: Thrift generated parquet files have different array element - // schema names than avro. Thrift parquet uses array_schema_name + "_tuple" - // as opposed to "array" used by default. For more information, check - // TestThriftSchemaConverter.java in parquet.thrift. - val THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX = "_tuple" + val MAP_KEY_SCHEMA_NAME = "key" val MAP_VALUE_SCHEMA_NAME = "value" val MAP_SCHEMA_NAME = "map" @@ -80,787 +34,4 @@ private[sql] object CatalystConverter { type ArrayScalaType[T] = Seq[T] type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] - - protected[parquet] def createConverter( - field: FieldType, - fieldIndex: Int, - parent: CatalystConverter): Converter = { - val fieldType: DataType = field.dataType - fieldType match { - case udt: UserDefinedType[_] => { - createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) - } - // For native JVM types we use a converter with native arrays - case ArrayType(elementType: AtomicType, false) => { - new CatalystNativeArrayConverter(elementType, fieldIndex, parent) - } - // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType, false) => { - new CatalystArrayConverter(elementType, fieldIndex, parent) - } - case ArrayType(elementType: DataType, true) => { - new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent) - } - case StructType(fields: Array[StructField]) => { - new CatalystStructConverter(fields, fieldIndex, parent) - } - case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { - new CatalystMapConverter( - Array( - new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), - fieldIndex, - parent) - } - // Strings, Shorts and Bytes do not have a corresponding type in Parquet - // so we need to treat them separately - case StringType => - new CatalystPrimitiveStringConverter(parent, fieldIndex) - case ShortType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.InternalType]) - } - } - case ByteType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.InternalType]) - } - } - case DateType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateDate(fieldIndex, value.asInstanceOf[DateType.InternalType]) - } - } - case d: DecimalType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateDecimal(fieldIndex, value, d) - } - } - case TimestampType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateTimestamp(fieldIndex, value) - } - } - // All other primitive types use the default converter - case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => { - // note: need the type tag here! - new CatalystPrimitiveConverter(parent, fieldIndex) - } - case _ => throw new RuntimeException( - s"unable to convert datatype ${field.dataType.toString} in CatalystConverter") - } - } - - protected[parquet] def createRootConverter( - parquetSchema: MessageType, - attributes: Seq[Attribute]): CatalystConverter = { - // For non-nested types we use the optimized Row converter - if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) { - new CatalystPrimitiveRowConverter(attributes.toArray) - } else { - new CatalystGroupConverter(attributes.toArray) - } - } -} - -private[parquet] abstract class CatalystConverter extends GroupConverter { - /** - * The number of fields this group has - */ - protected[parquet] val size: Int - - /** - * The index of this converter in the parent - */ - protected[parquet] val index: Int - - /** - * The parent converter - */ - protected[parquet] val parent: CatalystConverter - - /** - * Called by child converters to update their value in its parent (this). - * Note that if possible the more specific update methods below should be used - * to avoid auto-boxing of native JVM types. - * - * @param fieldIndex - * @param value - */ - protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit - - protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.getBytes) - - protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - updateField(fieldIndex, UTF8String.fromBytes(value)) - - protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, readTimestamp(value)) - - protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = - updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) - - protected[parquet] def isRootConverter: Boolean = parent == null - - protected[parquet] def clearBuffer(): Unit - - /** - * Should only be called in the root (group) converter! - * - * @return - */ - def getCurrentRecord: InternalRow = throw new UnsupportedOperationException - - /** - * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in - * a long (i.e. precision <= 18) - * - * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. - */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { - val precision = ctype.precisionInfo.get.precision - val scale = ctype.precisionInfo.get.scale - val bytes = value.getBytes - require(bytes.length <= 16, "Decimal field too large to read") - var unscaled = 0L - var i = 0 - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xFF) - i += 1 - } - // Make sure unscaled has the right sign, by sign-extending the first bit - val numBits = 8 * bytes.length - unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) - dest.set(unscaled, precision, scale) - } - - /** - * Read a Timestamp value from a Parquet Int96Value - */ - protected[parquet] def readTimestamp(value: Binary): Long = { - Preconditions.checkArgument(value.length() == 12, "Must be 12 bytes") - val buf = value.toByteBuffer - buf.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. - * - * @param schema The corresponding Catalyst schema in the form of a list of attributes. - */ -private[parquet] class CatalystGroupConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var current: ArrayBuffer[Any], - protected[parquet] var buffer: ArrayBuffer[InternalRow]) - extends CatalystConverter { - - def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = - this( - schema, - index, - parent, - current = null, - buffer = new ArrayBuffer[InternalRow]( - CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - /** - * This constructor is used for the root converter only! - */ - def this(attributes: Array[Attribute]) = - this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override def getCurrentRecord: InternalRow = { - assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") - // TODO: use iterators if possible - // Note: this will ever only be called in the root converter when the record has been - // fully processed. Therefore it will be difficult to use mutable rows instead, since - // any non-root converter never would be sure when it would be safe to re-use the buffer. - new GenericInternalRow(current.toArray) - } - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current.update(fieldIndex, value) - } - - override protected[parquet] def clearBuffer(): Unit = buffer.clear() - - override def start(): Unit = { - current = ArrayBuffer.fill(size)(null) - converters.foreach { converter => - if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer() - } - } - } - - override def end(): Unit = { - if (!isRootConverter) { - assert(current != null) // there should be no empty groups - buffer.append(new GenericInternalRow(current.toArray)) - parent.updateField(index, new GenericInternalRow(buffer.toArray.asInstanceOf[Array[Any]])) - } - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. Note that his - * converter is optimized for rows of primitive types (non-nested records). - */ -private[parquet] class CatalystPrimitiveRowConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] var current: MutableRow) - extends CatalystConverter { - - // This constructor is used for the root converter only - def this(attributes: Array[Attribute]) = - this( - attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), - new SpecificMutableRow(attributes.map(_.dataType))) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override val index = 0 - - override val parent = null - - // Should be only called in root group converter! - override def getCurrentRecord: InternalRow = current - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - throw new UnsupportedOperationException // child converters should use the - // specific update methods below - } - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - var i = 0 - while (i < size) { - current.setNullAt(i) - i = i + 1 - } - } - - override def end(): Unit = {} - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - current.setBoolean(fieldIndex, value) - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - current.setInt(fieldIndex, value) - - override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - current.setInt(fieldIndex, value) - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - current.setLong(fieldIndex, value) - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - current.setShort(fieldIndex, value) - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - current.setByte(fieldIndex, value) - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - current.setDouble(fieldIndex, value) - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - current.setFloat(fieldIndex, value) - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, value.getBytes) - - override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - current.update(fieldIndex, UTF8String.fromBytes(value)) - - override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - current.setLong(fieldIndex, readTimestamp(value)) - - override protected[parquet] def updateDecimal( - fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { - var decimal = current(fieldIndex).asInstanceOf[Decimal] - if (decimal == null) { - decimal = new Decimal - current(fieldIndex) = decimal - } - readDecimal(decimal, value, ctype) - } -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveConverter( - parent: CatalystConverter, - fieldIndex: Int) extends PrimitiveConverter { - override def addBinary(value: Binary): Unit = - parent.updateBinary(fieldIndex, value) - - override def addBoolean(value: Boolean): Unit = - parent.updateBoolean(fieldIndex, value) - - override def addDouble(value: Double): Unit = - parent.updateDouble(fieldIndex, value) - - override def addFloat(value: Float): Unit = - parent.updateFloat(fieldIndex, value) - - override def addInt(value: Int): Unit = - parent.updateInt(fieldIndex, value) - - override def addLong(value: Long): Unit = - parent.updateLong(fieldIndex, value) -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. - * Supports dictionaries to reduce Binary to String conversion overhead. - * - * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) - extends CatalystPrimitiveConverter(parent, fieldIndex) { - - private[this] var dict: Array[Array[Byte]] = null - - override def hasDictionarySupport: Boolean = true - - override def setDictionary(dictionary: Dictionary): Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } - - override def addValueFromDictionary(dictionaryId: Int): Unit = - parent.updateString(fieldIndex, dict(dictionaryId)) - - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.getBytes) -} - -private[parquet] object CatalystArrayConverter { - val INITIAL_ARRAY_SIZE = 20 -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - // fieldIndex is ignored (assumed to be zero but not checked) - if (value == null) { - throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") - } - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = { - if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer() - } - } - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (native) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param capacity The (initial) capacity of the buffer - */ -private[parquet] class CatalystNativeArrayConverter( - val elementType: AtomicType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) - extends CatalystConverter { - - type NativeType = elementType.InternalType - - private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) - - private var elements: Int = 0 - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = { - checkGrowBuffer() - buffer(elements) = value.getBytes.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { - checkGrowBuffer() - buffer(elements) = UTF8String.fromBytes(value).asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def clearBuffer(): Unit = { - elements = 0 - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField( - index, - buffer.slice(0, elements).toSeq) - clearBuffer() - } - - private def checkGrowBuffer(): Unit = { - if (elements >= capacity) { - val newCapacity = 2 * capacity - val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity) - Array.copy(buffer, 0, tmp, 0, capacity) - buffer = tmp - capacity = newCapacity - } - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array contains null (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayContainsNullConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = new CatalystConverter { - - private var current: Any = null - - val converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - override def end(): Unit = parent.updateField(index, current) - - override def start(): Unit = { - current = null - } - - override protected[parquet] val size: Int = 1 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent = CatalystArrayContainsNullConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current = value - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * This converter is for multi-element groups of primitive or complex types - * that have repetition level optional or required (so struct fields). - * - * @param schema The corresponding Catalyst schema in the form of a list of - * attributes. - * @param index - * @param parent - */ -private[parquet] class CatalystStructConverter( - override protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystGroupConverter(schema, index, parent) { - - override protected[parquet] def clearBuffer(): Unit = {} - - // TODO: think about reusing the buffer - override def end(): Unit = { - assert(!isRootConverter) - // here we need to make sure to use StructScalaType - // Note: we need to actually make a copy of the array since we - // may be in a nested field - parent.updateField(index, new GenericInternalRow(current.toArray)) - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts two-element groups that - * match the characteristics of a map (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.MapType]]. - * - * @param schema - * @param index - * @param parent - */ -private[parquet] class CatalystMapConverter( - protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystConverter { - - private val map = new HashMap[Any, Any]() - - private val keyValueConverter = new CatalystConverter { - private var currentKey: Any = null - private var currentValue: Any = null - val keyConverter = CatalystConverter.createConverter(schema(0), 0, this) - val valueConverter = CatalystConverter.createConverter(schema(1), 1, this) - - override def getConverter(fieldIndex: Int): Converter = { - if (fieldIndex == 0) keyConverter else valueConverter - } - - override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue - - override def start(): Unit = { - currentKey = null - currentValue = null - } - - override protected[parquet] val size: Int = 2 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - fieldIndex match { - case 0 => - currentKey = value - case 1 => - currentValue = value - case _ => - new RuntimePermission(s"trying to update Map with fieldIndex $fieldIndex") - } - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override protected[parquet] val size: Int = 1 - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - map.clear() - } - - override def end(): Unit = { - // here we need to make sure to use MapScalaType - parent.updateField(index, map.toMap) - } - - override def getConverter(fieldIndex: Int): Converter = keyValueConverter - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 8402cd756140d..e8851ddb68026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.parquet -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import java.util import java.util.{HashMap => JHashMap} +import scala.collection.JavaConversions._ + import org.apache.hadoop.conf.Configuration import org.apache.parquet.column.ParquetProperties import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport} +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport, WriteSupport} import org.apache.parquet.io.api._ import org.apache.parquet.schema.MessageType @@ -36,87 +39,133 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * A `parquet.io.api.RecordMaterializer` for Rows. + * A [[RecordMaterializer]] for Catalyst rows. * - *@param root The root group converter for the record. + * @param parquetSchema Parquet schema of the records to be read + * @param catalystSchema Catalyst schema of the rows to be constructed */ -private[parquet] class RowRecordMaterializer(root: CatalystConverter) +private[parquet] class RowRecordMaterializer(parquetSchema: MessageType, catalystSchema: StructType) extends RecordMaterializer[InternalRow] { - def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = - this(CatalystConverter.createRootConverter(parquetSchema, attributes)) + private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) - override def getCurrentRecord: InternalRow = root.getCurrentRecord + override def getCurrentRecord: InternalRow = rootConverter.currentRow - override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] + override def getRootConverter: GroupConverter = rootConverter } -/** - * A `parquet.hadoop.api.ReadSupport` for Row objects. - */ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logging { - override def prepareForRead( conf: Configuration, - stringMap: java.util.Map[String, String], + keyValueMetaData: util.Map[String, String], fileSchema: MessageType, readContext: ReadContext): RecordMaterializer[InternalRow] = { - log.debug(s"preparing for read with Parquet file schema $fileSchema") - // Note: this very much imitates AvroParquet - val parquetSchema = readContext.getRequestedSchema - var schema: Seq[Attribute] = null - - if (readContext.getReadSupportMetadata != null) { - // first try to find the read schema inside the metadata (can result from projections) - if ( - readContext - .getReadSupportMetadata - .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) != null) { - schema = ParquetTypesConverter.convertFromString( - readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - } else { - // if unavailable, try the schema that was read originally from the file or provided - // during the creation of the Parquet relation - if (readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { - schema = ParquetTypesConverter.convertFromString( - readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) - } + log.debug(s"Preparing for read Parquet file with message type: $fileSchema") + + val toCatalyst = new CatalystSchemaConverter(conf) + val parquetRequestedSchema = readContext.getRequestedSchema + + val catalystRequestedSchema = + Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => + metadata + // First tries to read requested schema, which may result from projections + .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + // If not available, tries to read Catalyst schema from file metadata. It's only + // available if the target file is written by Spark SQL. + .orElse(metadata.get(RowReadSupport.SPARK_METADATA_KEY)) + }.map(StructType.fromString).getOrElse { + logDebug("Catalyst schema not available, falling back to Parquet schema") + toCatalyst.convert(parquetRequestedSchema) } - } - // if both unavailable, fall back to deducing the schema from the given Parquet schema - // TODO: Why it can be null? - if (schema == null) { - log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false, true) - } - log.debug(s"list of attributes that will be read: $schema") - new RowRecordMaterializer(parquetSchema, schema) + + logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema") + new RowRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) } - override def init( - configuration: Configuration, - keyValueMetaData: java.util.Map[String, String], - fileSchema: MessageType): ReadContext = { - var parquetSchema = fileSchema - val metadata = new JHashMap[String, String]() - val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) - - if (requestedAttributes != null) { - // If the parquet file is thrift derived, there is a good chance that - // it will have the thrift class in metadata. - val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") - parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) - metadata.put( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedAttributes)) - } + override def init(context: InitContext): ReadContext = { + val conf = context.getConfiguration + + // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst + // schema of this file from its the metadata. + val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA)) + + // Optional schema of requested columns, in the form of a string serialized from a Catalyst + // `StructType` containing all requested columns. + val maybeRequestedSchema = Option(conf.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) + + // Below we construct a Parquet schema containing all requested columns. This schema tells + // Parquet which columns to read. + // + // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, + // we have to fallback to the full file schema which contains all columns in the file. + // Obviously this may waste IO bandwidth since it may read more columns than requested. + // + // Two things to note: + // + // 1. It's possible that some requested columns don't exist in the target Parquet file. For + // example, in the case of schema merging, the globally merged schema may contain extra + // columns gathered from other Parquet files. These columns will be simply filled with nulls + // when actually reading the target Parquet file. + // + // 2. When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to + // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to + // non-standard behaviors of some Parquet libraries/tools. For example, a Parquet file + // containing a single integer array field `f1` may have the following legacy 2-level + // structure: + // + // message root { + // optional group f1 (LIST) { + // required INT32 element; + // } + // } + // + // while `CatalystSchemaConverter` may generate a standard 3-level structure: + // + // message root { + // optional group f1 (LIST) { + // repeated group list { + // required INT32 element; + // } + // } + // } + // + // Apparently, we can't use the 2nd schema to read the target Parquet file as they have + // different physical structures. + val parquetRequestedSchema = + maybeRequestedSchema.fold(context.getFileSchema) { schemaString => + val toParquet = new CatalystSchemaConverter(conf) + val fileSchema = context.getFileSchema.asGroupType() + val fileFieldNames = fileSchema.getFields.map(_.getName).toSet + + StructType + // Deserializes the Catalyst schema of requested columns + .fromString(schemaString) + .map { field => + if (fileFieldNames.contains(field.name)) { + // If the field exists in the target Parquet file, extracts the field type from the + // full file schema and makes a single-field Parquet schema + new MessageType("root", fileSchema.getType(field.name)) + } else { + // Otherwise, just resorts to `CatalystSchemaConverter` + toParquet.convert(StructType(Array(field))) + } + } + // Merges all single-field Parquet schemas to form a complete schema for all requested + // columns. Note that it's possible that no columns are requested at all (e.g., count + // some partition column of a partitioned Parquet table). That's why `fold` is used here + // and always fallback to an empty Parquet schema. + .fold(new MessageType("root")) { + _ union _ + } + } - val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - if (origAttributesStr != null) { - metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) - } + val metadata = + Map.empty[String, String] ++ + maybeRequestedSchema.map(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ + maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - new ReadSupport.ReadContext(parquetSchema, metadata) + logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema") + new ReadContext(parquetRequestedSchema, metadata) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index ce456e7fbe17e..01dd6f471bd7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -259,6 +259,10 @@ private[sql] class ParquetRelation2( broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown + val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + // Create the function to set variable Parquet confs at both driver and executor side. val initLocalJobFuncOpt = ParquetRelation2.initializeLocalJobFunc( @@ -266,7 +270,11 @@ private[sql] class ParquetRelation2( filters, dataSchema, useMetadataCache, - parquetFilterPushDown) _ + parquetFilterPushDown, + assumeBinaryIsString, + assumeInt96IsTimestamp, + followParquetFormatSpec) _ + // Create the function to set input paths at the driver side. val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ @@ -471,9 +479,12 @@ private[sql] object ParquetRelation2 extends Logging { filters: Array[Filter], dataSchema: StructType, useMetadataCache: Boolean, - parquetFilterPushDown: Boolean)(job: Job): Unit = { + parquetFilterPushDown: Boolean, + assumeBinaryIsString: Boolean, + assumeInt96IsTimestamp: Boolean, + followParquetFormatSpec: Boolean)(job: Job): Unit = { val conf = job.getConfiguration - conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName()) + conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName) // Try to push down filters when filter push-down is enabled. if (parquetFilterPushDown) { @@ -497,6 +508,11 @@ private[sql] object ParquetRelation2 extends Logging { // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) + + // Sets flags for Parquet schema conversion + conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) + conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) + conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) } /** This closure sets input paths at the driver side. */ diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md new file mode 100644 index 0000000000000..3dd9861b4896d --- /dev/null +++ b/sql/core/src/test/README.md @@ -0,0 +1,33 @@ +# Notes for Parquet compatibility tests + +The following directories and files are used for Parquet compatibility tests: + +``` +. +├── README.md # This file +├── avro +│   ├── parquet-compat.avdl # Testing Avro IDL +│   └── parquet-compat.avpr # !! NO TOUCH !! Protocol file generated from parquet-compat.avdl +├── gen-java # !! NO TOUCH !! Generated Java code +├── scripts +│   └── gen-code.sh # Script used to generate Java code for Thrift and Avro +└── thrift + └── parquet-compat.thrift # Testing Thrift schema +``` + +Generated Java code are used in the following test suites: + +- `org.apache.spark.sql.parquet.ParquetAvroCompatibilitySuite` +- `org.apache.spark.sql.parquet.ParquetThriftCompatibilitySuite` + +To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. + +When updating the testing Thrift schema and Avro IDL, please run `gen-code.sh` to update all the generated Java code. + +## Prerequisites + +Please ensure `avro-tools` and `thrift` are installed. You may install these two on Mac OS X via: + +```bash +$ brew install thrift avro-tools +``` diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl new file mode 100644 index 0000000000000..24729f6143e6c --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This is a test protocol for testing parquet-avro compatibility. +@namespace("org.apache.spark.sql.parquet.test.avro") +protocol CompatibilityTest { + record Nested { + array nested_ints_column; + string nested_string_column; + } + + record ParquetAvroCompat { + boolean bool_column; + int int_column; + long long_column; + float float_column; + double double_column; + bytes binary_column; + string string_column; + + union { null, boolean } maybe_bool_column; + union { null, int } maybe_int_column; + union { null, long } maybe_long_column; + union { null, float } maybe_float_column; + union { null, double } maybe_double_column; + union { null, bytes } maybe_binary_column; + union { null, string } maybe_string_column; + + array strings_column; + map string_to_int_column; + map> complex_column; + } +} diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr new file mode 100644 index 0000000000000..a83b7c990dd2e --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -0,0 +1,86 @@ +{ + "protocol" : "CompatibilityTest", + "namespace" : "org.apache.spark.sql.parquet.test.avro", + "types" : [ { + "type" : "record", + "name" : "Nested", + "fields" : [ { + "name" : "nested_ints_column", + "type" : { + "type" : "array", + "items" : "int" + } + }, { + "name" : "nested_string_column", + "type" : "string" + } ] + }, { + "type" : "record", + "name" : "ParquetAvroCompat", + "fields" : [ { + "name" : "bool_column", + "type" : "boolean" + }, { + "name" : "int_column", + "type" : "int" + }, { + "name" : "long_column", + "type" : "long" + }, { + "name" : "float_column", + "type" : "float" + }, { + "name" : "double_column", + "type" : "double" + }, { + "name" : "binary_column", + "type" : "bytes" + }, { + "name" : "string_column", + "type" : "string" + }, { + "name" : "maybe_bool_column", + "type" : [ "null", "boolean" ] + }, { + "name" : "maybe_int_column", + "type" : [ "null", "int" ] + }, { + "name" : "maybe_long_column", + "type" : [ "null", "long" ] + }, { + "name" : "maybe_float_column", + "type" : [ "null", "float" ] + }, { + "name" : "maybe_double_column", + "type" : [ "null", "double" ] + }, { + "name" : "maybe_binary_column", + "type" : [ "null", "bytes" ] + }, { + "name" : "maybe_string_column", + "type" : [ "null", "string" ] + }, { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } + }, { + "name" : "string_to_int_column", + "type" : { + "type" : "map", + "values" : "int" + } + }, { + "name" : "complex_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "Nested" + } + } + } ] + } ], + "messages" : { } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java new file mode 100644 index 0000000000000..daec65a5bbe57 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java @@ -0,0 +1,17 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.parquet.test.avro; + +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public interface CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"types\":[{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + + @SuppressWarnings("all") + public interface Callback extends CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.parquet.test.avro.CompatibilityTest.PROTOCOL; + } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java new file mode 100644 index 0000000000000..051f1ee903863 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List nested_ints_column; + @Deprecated public java.lang.String nested_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public Nested() {} + + /** + * All-args constructor. + */ + public Nested(java.util.List nested_ints_column, java.lang.String nested_string_column) { + this.nested_ints_column = nested_ints_column; + this.nested_string_column = nested_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return nested_ints_column; + case 1: return nested_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: nested_ints_column = (java.util.List)value$; break; + case 1: nested_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'nested_ints_column' field. + */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** + * Sets the value of the 'nested_ints_column' field. + * @param value the value to set. + */ + public void setNestedIntsColumn(java.util.List value) { + this.nested_ints_column = value; + } + + /** + * Gets the value of the 'nested_string_column' field. + */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** + * Sets the value of the 'nested_string_column' field. + * @param value the value to set. + */ + public void setNestedStringColumn(java.lang.String value) { + this.nested_string_column = value; + } + + /** Creates a new Nested RecordBuilder */ + public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(); + } + + /** Creates a new Nested RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + } + + /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ + public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + } + + /** + * RecordBuilder for Nested instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List nested_ints_column; + private java.lang.String nested_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + super(other); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing Nested instance */ + private Builder(org.apache.spark.sql.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'nested_ints_column' field */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** Sets the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + validate(fields()[0], value); + this.nested_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'nested_ints_column' field has been set */ + public boolean hasNestedIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + nested_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'nested_string_column' field */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** Sets the value of the 'nested_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + validate(fields()[1], value); + this.nested_string_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'nested_string_column' field has been set */ + public boolean hasNestedStringColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'nested_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + nested_string_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public Nested build() { + try { + Nested record = new Nested(); + record.nested_ints_column = fieldSetFlags()[0] ? this.nested_ints_column : (java.util.List) defaultValue(fields()[0]); + record.nested_string_column = fieldSetFlags()[1] ? this.nested_string_column : (java.lang.String) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java new file mode 100644 index 0000000000000..354c9d73cca31 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java @@ -0,0 +1,1001 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean bool_column; + @Deprecated public int int_column; + @Deprecated public long long_column; + @Deprecated public float float_column; + @Deprecated public double double_column; + @Deprecated public java.nio.ByteBuffer binary_column; + @Deprecated public java.lang.String string_column; + @Deprecated public java.lang.Boolean maybe_bool_column; + @Deprecated public java.lang.Integer maybe_int_column; + @Deprecated public java.lang.Long maybe_long_column; + @Deprecated public java.lang.Float maybe_float_column; + @Deprecated public java.lang.Double maybe_double_column; + @Deprecated public java.nio.ByteBuffer maybe_binary_column; + @Deprecated public java.lang.String maybe_string_column; + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.Map string_to_int_column; + @Deprecated public java.util.Map> complex_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetAvroCompat() {} + + /** + * All-args constructor. + */ + public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + this.bool_column = bool_column; + this.int_column = int_column; + this.long_column = long_column; + this.float_column = float_column; + this.double_column = double_column; + this.binary_column = binary_column; + this.string_column = string_column; + this.maybe_bool_column = maybe_bool_column; + this.maybe_int_column = maybe_int_column; + this.maybe_long_column = maybe_long_column; + this.maybe_float_column = maybe_float_column; + this.maybe_double_column = maybe_double_column; + this.maybe_binary_column = maybe_binary_column; + this.maybe_string_column = maybe_string_column; + this.strings_column = strings_column; + this.string_to_int_column = string_to_int_column; + this.complex_column = complex_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return bool_column; + case 1: return int_column; + case 2: return long_column; + case 3: return float_column; + case 4: return double_column; + case 5: return binary_column; + case 6: return string_column; + case 7: return maybe_bool_column; + case 8: return maybe_int_column; + case 9: return maybe_long_column; + case 10: return maybe_float_column; + case 11: return maybe_double_column; + case 12: return maybe_binary_column; + case 13: return maybe_string_column; + case 14: return strings_column; + case 15: return string_to_int_column; + case 16: return complex_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: bool_column = (java.lang.Boolean)value$; break; + case 1: int_column = (java.lang.Integer)value$; break; + case 2: long_column = (java.lang.Long)value$; break; + case 3: float_column = (java.lang.Float)value$; break; + case 4: double_column = (java.lang.Double)value$; break; + case 5: binary_column = (java.nio.ByteBuffer)value$; break; + case 6: string_column = (java.lang.String)value$; break; + case 7: maybe_bool_column = (java.lang.Boolean)value$; break; + case 8: maybe_int_column = (java.lang.Integer)value$; break; + case 9: maybe_long_column = (java.lang.Long)value$; break; + case 10: maybe_float_column = (java.lang.Float)value$; break; + case 11: maybe_double_column = (java.lang.Double)value$; break; + case 12: maybe_binary_column = (java.nio.ByteBuffer)value$; break; + case 13: maybe_string_column = (java.lang.String)value$; break; + case 14: strings_column = (java.util.List)value$; break; + case 15: string_to_int_column = (java.util.Map)value$; break; + case 16: complex_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'bool_column' field. + */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** + * Sets the value of the 'bool_column' field. + * @param value the value to set. + */ + public void setBoolColumn(java.lang.Boolean value) { + this.bool_column = value; + } + + /** + * Gets the value of the 'int_column' field. + */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** + * Sets the value of the 'int_column' field. + * @param value the value to set. + */ + public void setIntColumn(java.lang.Integer value) { + this.int_column = value; + } + + /** + * Gets the value of the 'long_column' field. + */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** + * Sets the value of the 'long_column' field. + * @param value the value to set. + */ + public void setLongColumn(java.lang.Long value) { + this.long_column = value; + } + + /** + * Gets the value of the 'float_column' field. + */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** + * Sets the value of the 'float_column' field. + * @param value the value to set. + */ + public void setFloatColumn(java.lang.Float value) { + this.float_column = value; + } + + /** + * Gets the value of the 'double_column' field. + */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** + * Sets the value of the 'double_column' field. + * @param value the value to set. + */ + public void setDoubleColumn(java.lang.Double value) { + this.double_column = value; + } + + /** + * Gets the value of the 'binary_column' field. + */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** + * Sets the value of the 'binary_column' field. + * @param value the value to set. + */ + public void setBinaryColumn(java.nio.ByteBuffer value) { + this.binary_column = value; + } + + /** + * Gets the value of the 'string_column' field. + */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** + * Sets the value of the 'string_column' field. + * @param value the value to set. + */ + public void setStringColumn(java.lang.String value) { + this.string_column = value; + } + + /** + * Gets the value of the 'maybe_bool_column' field. + */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** + * Sets the value of the 'maybe_bool_column' field. + * @param value the value to set. + */ + public void setMaybeBoolColumn(java.lang.Boolean value) { + this.maybe_bool_column = value; + } + + /** + * Gets the value of the 'maybe_int_column' field. + */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** + * Sets the value of the 'maybe_int_column' field. + * @param value the value to set. + */ + public void setMaybeIntColumn(java.lang.Integer value) { + this.maybe_int_column = value; + } + + /** + * Gets the value of the 'maybe_long_column' field. + */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** + * Sets the value of the 'maybe_long_column' field. + * @param value the value to set. + */ + public void setMaybeLongColumn(java.lang.Long value) { + this.maybe_long_column = value; + } + + /** + * Gets the value of the 'maybe_float_column' field. + */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** + * Sets the value of the 'maybe_float_column' field. + * @param value the value to set. + */ + public void setMaybeFloatColumn(java.lang.Float value) { + this.maybe_float_column = value; + } + + /** + * Gets the value of the 'maybe_double_column' field. + */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** + * Sets the value of the 'maybe_double_column' field. + * @param value the value to set. + */ + public void setMaybeDoubleColumn(java.lang.Double value) { + this.maybe_double_column = value; + } + + /** + * Gets the value of the 'maybe_binary_column' field. + */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** + * Sets the value of the 'maybe_binary_column' field. + * @param value the value to set. + */ + public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { + this.maybe_binary_column = value; + } + + /** + * Gets the value of the 'maybe_string_column' field. + */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** + * Sets the value of the 'maybe_string_column' field. + * @param value the value to set. + */ + public void setMaybeStringColumn(java.lang.String value) { + this.maybe_string_column = value; + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'string_to_int_column' field. + */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** + * Sets the value of the 'string_to_int_column' field. + * @param value the value to set. + */ + public void setStringToIntColumn(java.util.Map value) { + this.string_to_int_column = value; + } + + /** + * Gets the value of the 'complex_column' field. + */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** + * Sets the value of the 'complex_column' field. + * @param value the value to set. + */ + public void setComplexColumn(java.util.Map> value) { + this.complex_column = value; + } + + /** Creates a new ParquetAvroCompat RecordBuilder */ + public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ + public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** + * RecordBuilder for ParquetAvroCompat instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean bool_column; + private int int_column; + private long long_column; + private float float_column; + private double double_column; + private java.nio.ByteBuffer binary_column; + private java.lang.String string_column; + private java.lang.Boolean maybe_bool_column; + private java.lang.Integer maybe_int_column; + private java.lang.Long maybe_long_column; + private java.lang.Float maybe_float_column; + private java.lang.Double maybe_double_column; + private java.nio.ByteBuffer maybe_binary_column; + private java.lang.String maybe_string_column; + private java.util.List strings_column; + private java.util.Map string_to_int_column; + private java.util.Map> complex_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { + super(other); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); + fieldSetFlags()[8] = true; + } + if (isValidValue(fields()[9], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); + fieldSetFlags()[14] = true; + } + if (isValidValue(fields()[15], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); + fieldSetFlags()[15] = true; + } + if (isValidValue(fields()[16], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); + fieldSetFlags()[16] = true; + } + } + + /** Creates a Builder by copying an existing ParquetAvroCompat instance */ + private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); + fieldSetFlags()[8] = true; + } + if (isValidValue(fields()[9], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); + fieldSetFlags()[14] = true; + } + if (isValidValue(fields()[15], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); + fieldSetFlags()[15] = true; + } + if (isValidValue(fields()[16], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); + fieldSetFlags()[16] = true; + } + } + + /** Gets the value of the 'bool_column' field */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** Sets the value of the 'bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { + validate(fields()[0], value); + this.bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'bool_column' field has been set */ + public boolean hasBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int_column' field */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** Sets the value of the 'int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { + validate(fields()[1], value); + this.int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int_column' field has been set */ + public boolean hasIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long_column' field */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** Sets the value of the 'long_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { + validate(fields()[2], value); + this.long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long_column' field has been set */ + public boolean hasLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float_column' field */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** Sets the value of the 'float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { + validate(fields()[3], value); + this.float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float_column' field has been set */ + public boolean hasFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double_column' field */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** Sets the value of the 'double_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { + validate(fields()[4], value); + this.double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double_column' field has been set */ + public boolean hasDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'binary_column' field */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** Sets the value of the 'binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'binary_column' field has been set */ + public boolean hasBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { + binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'string_column' field */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** Sets the value of the 'string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'string_column' field has been set */ + public boolean hasStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { + string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + /** Gets the value of the 'maybe_bool_column' field */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** Sets the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { + validate(fields()[7], value); + this.maybe_bool_column = value; + fieldSetFlags()[7] = true; + return this; + } + + /** Checks whether the 'maybe_bool_column' field has been set */ + public boolean hasMaybeBoolColumn() { + return fieldSetFlags()[7]; + } + + /** Clears the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { + maybe_bool_column = null; + fieldSetFlags()[7] = false; + return this; + } + + /** Gets the value of the 'maybe_int_column' field */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** Sets the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { + validate(fields()[8], value); + this.maybe_int_column = value; + fieldSetFlags()[8] = true; + return this; + } + + /** Checks whether the 'maybe_int_column' field has been set */ + public boolean hasMaybeIntColumn() { + return fieldSetFlags()[8]; + } + + /** Clears the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { + maybe_int_column = null; + fieldSetFlags()[8] = false; + return this; + } + + /** Gets the value of the 'maybe_long_column' field */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** Sets the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { + validate(fields()[9], value); + this.maybe_long_column = value; + fieldSetFlags()[9] = true; + return this; + } + + /** Checks whether the 'maybe_long_column' field has been set */ + public boolean hasMaybeLongColumn() { + return fieldSetFlags()[9]; + } + + /** Clears the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { + maybe_long_column = null; + fieldSetFlags()[9] = false; + return this; + } + + /** Gets the value of the 'maybe_float_column' field */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** Sets the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { + validate(fields()[10], value); + this.maybe_float_column = value; + fieldSetFlags()[10] = true; + return this; + } + + /** Checks whether the 'maybe_float_column' field has been set */ + public boolean hasMaybeFloatColumn() { + return fieldSetFlags()[10]; + } + + /** Clears the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { + maybe_float_column = null; + fieldSetFlags()[10] = false; + return this; + } + + /** Gets the value of the 'maybe_double_column' field */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** Sets the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { + validate(fields()[11], value); + this.maybe_double_column = value; + fieldSetFlags()[11] = true; + return this; + } + + /** Checks whether the 'maybe_double_column' field has been set */ + public boolean hasMaybeDoubleColumn() { + return fieldSetFlags()[11]; + } + + /** Clears the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { + maybe_double_column = null; + fieldSetFlags()[11] = false; + return this; + } + + /** Gets the value of the 'maybe_binary_column' field */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** Sets the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[12], value); + this.maybe_binary_column = value; + fieldSetFlags()[12] = true; + return this; + } + + /** Checks whether the 'maybe_binary_column' field has been set */ + public boolean hasMaybeBinaryColumn() { + return fieldSetFlags()[12]; + } + + /** Clears the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { + maybe_binary_column = null; + fieldSetFlags()[12] = false; + return this; + } + + /** Gets the value of the 'maybe_string_column' field */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** Sets the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { + validate(fields()[13], value); + this.maybe_string_column = value; + fieldSetFlags()[13] = true; + return this; + } + + /** Checks whether the 'maybe_string_column' field has been set */ + public boolean hasMaybeStringColumn() { + return fieldSetFlags()[13]; + } + + /** Clears the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { + maybe_string_column = null; + fieldSetFlags()[13] = false; + return this; + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + validate(fields()[14], value); + this.strings_column = value; + fieldSetFlags()[14] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[14]; + } + + /** Clears the value of the 'strings_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[14] = false; + return this; + } + + /** Gets the value of the 'string_to_int_column' field */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** Sets the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + validate(fields()[15], value); + this.string_to_int_column = value; + fieldSetFlags()[15] = true; + return this; + } + + /** Checks whether the 'string_to_int_column' field has been set */ + public boolean hasStringToIntColumn() { + return fieldSetFlags()[15]; + } + + /** Clears the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + string_to_int_column = null; + fieldSetFlags()[15] = false; + return this; + } + + /** Gets the value of the 'complex_column' field */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** Sets the value of the 'complex_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + validate(fields()[16], value); + this.complex_column = value; + fieldSetFlags()[16] = true; + return this; + } + + /** Checks whether the 'complex_column' field has been set */ + public boolean hasComplexColumn() { + return fieldSetFlags()[16]; + } + + /** Clears the value of the 'complex_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + complex_column = null; + fieldSetFlags()[16] = false; + return this; + } + + @Override + public ParquetAvroCompat build() { + try { + ParquetAvroCompat record = new ParquetAvroCompat(); + record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); + record.float_column = fieldSetFlags()[3] ? this.float_column : (java.lang.Float) defaultValue(fields()[3]); + record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); + record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); + record.maybe_bool_column = fieldSetFlags()[7] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[7]); + record.maybe_int_column = fieldSetFlags()[8] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[8]); + record.maybe_long_column = fieldSetFlags()[9] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[9]); + record.maybe_float_column = fieldSetFlags()[10] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[10]); + record.maybe_double_column = fieldSetFlags()[11] ? this.maybe_double_column : (java.lang.Double) defaultValue(fields()[11]); + record.maybe_binary_column = fieldSetFlags()[12] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[12]); + record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); + record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); + record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); + record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java new file mode 100644 index 0000000000000..281e60cc3ae34 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java @@ -0,0 +1,541 @@ +/** + * Autogenerated by Thrift Compiler (0.9.2) + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +package org.apache.spark.sql.parquet.test.thrift; + +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.protocol.TProtocolException; +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.TException; +import org.apache.thrift.async.AsyncMethodCallback; +import org.apache.thrift.server.AbstractNonblockingServer.*; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.Set; +import java.util.HashSet; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.annotation.Generated; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings({"cast", "rawtypes", "serial", "unchecked"}) +@Generated(value = "Autogenerated by Thrift Compiler (0.9.2)", date = "2015-7-7") +public class Nested implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Nested"); + + private static final org.apache.thrift.protocol.TField NESTED_INTS_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("nestedIntsColumn", org.apache.thrift.protocol.TType.LIST, (short)1); + private static final org.apache.thrift.protocol.TField NESTED_STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("nestedStringColumn", org.apache.thrift.protocol.TType.STRING, (short)2); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new NestedStandardSchemeFactory()); + schemes.put(TupleScheme.class, new NestedTupleSchemeFactory()); + } + + public List nestedIntsColumn; // required + public String nestedStringColumn; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + NESTED_INTS_COLUMN((short)1, "nestedIntsColumn"), + NESTED_STRING_COLUMN((short)2, "nestedStringColumn"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // NESTED_INTS_COLUMN + return NESTED_INTS_COLUMN; + case 2: // NESTED_STRING_COLUMN + return NESTED_STRING_COLUMN; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. + */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.NESTED_INTS_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("nestedIntsColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.NESTED_STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("nestedStringColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Nested.class, metaDataMap); + } + + public Nested() { + } + + public Nested( + List nestedIntsColumn, + String nestedStringColumn) + { + this(); + this.nestedIntsColumn = nestedIntsColumn; + this.nestedStringColumn = nestedStringColumn; + } + + /** + * Performs a deep copy on other. + */ + public Nested(Nested other) { + if (other.isSetNestedIntsColumn()) { + List __this__nestedIntsColumn = new ArrayList(other.nestedIntsColumn); + this.nestedIntsColumn = __this__nestedIntsColumn; + } + if (other.isSetNestedStringColumn()) { + this.nestedStringColumn = other.nestedStringColumn; + } + } + + public Nested deepCopy() { + return new Nested(this); + } + + @Override + public void clear() { + this.nestedIntsColumn = null; + this.nestedStringColumn = null; + } + + public int getNestedIntsColumnSize() { + return (this.nestedIntsColumn == null) ? 0 : this.nestedIntsColumn.size(); + } + + public java.util.Iterator getNestedIntsColumnIterator() { + return (this.nestedIntsColumn == null) ? null : this.nestedIntsColumn.iterator(); + } + + public void addToNestedIntsColumn(int elem) { + if (this.nestedIntsColumn == null) { + this.nestedIntsColumn = new ArrayList(); + } + this.nestedIntsColumn.add(elem); + } + + public List getNestedIntsColumn() { + return this.nestedIntsColumn; + } + + public Nested setNestedIntsColumn(List nestedIntsColumn) { + this.nestedIntsColumn = nestedIntsColumn; + return this; + } + + public void unsetNestedIntsColumn() { + this.nestedIntsColumn = null; + } + + /** Returns true if field nestedIntsColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetNestedIntsColumn() { + return this.nestedIntsColumn != null; + } + + public void setNestedIntsColumnIsSet(boolean value) { + if (!value) { + this.nestedIntsColumn = null; + } + } + + public String getNestedStringColumn() { + return this.nestedStringColumn; + } + + public Nested setNestedStringColumn(String nestedStringColumn) { + this.nestedStringColumn = nestedStringColumn; + return this; + } + + public void unsetNestedStringColumn() { + this.nestedStringColumn = null; + } + + /** Returns true if field nestedStringColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetNestedStringColumn() { + return this.nestedStringColumn != null; + } + + public void setNestedStringColumnIsSet(boolean value) { + if (!value) { + this.nestedStringColumn = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case NESTED_INTS_COLUMN: + if (value == null) { + unsetNestedIntsColumn(); + } else { + setNestedIntsColumn((List)value); + } + break; + + case NESTED_STRING_COLUMN: + if (value == null) { + unsetNestedStringColumn(); + } else { + setNestedStringColumn((String)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case NESTED_INTS_COLUMN: + return getNestedIntsColumn(); + + case NESTED_STRING_COLUMN: + return getNestedStringColumn(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case NESTED_INTS_COLUMN: + return isSetNestedIntsColumn(); + case NESTED_STRING_COLUMN: + return isSetNestedStringColumn(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof Nested) + return this.equals((Nested)that); + return false; + } + + public boolean equals(Nested that) { + if (that == null) + return false; + + boolean this_present_nestedIntsColumn = true && this.isSetNestedIntsColumn(); + boolean that_present_nestedIntsColumn = true && that.isSetNestedIntsColumn(); + if (this_present_nestedIntsColumn || that_present_nestedIntsColumn) { + if (!(this_present_nestedIntsColumn && that_present_nestedIntsColumn)) + return false; + if (!this.nestedIntsColumn.equals(that.nestedIntsColumn)) + return false; + } + + boolean this_present_nestedStringColumn = true && this.isSetNestedStringColumn(); + boolean that_present_nestedStringColumn = true && that.isSetNestedStringColumn(); + if (this_present_nestedStringColumn || that_present_nestedStringColumn) { + if (!(this_present_nestedStringColumn && that_present_nestedStringColumn)) + return false; + if (!this.nestedStringColumn.equals(that.nestedStringColumn)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + List list = new ArrayList(); + + boolean present_nestedIntsColumn = true && (isSetNestedIntsColumn()); + list.add(present_nestedIntsColumn); + if (present_nestedIntsColumn) + list.add(nestedIntsColumn); + + boolean present_nestedStringColumn = true && (isSetNestedStringColumn()); + list.add(present_nestedStringColumn); + if (present_nestedStringColumn) + list.add(nestedStringColumn); + + return list.hashCode(); + } + + @Override + public int compareTo(Nested other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + + lastComparison = Boolean.valueOf(isSetNestedIntsColumn()).compareTo(other.isSetNestedIntsColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetNestedIntsColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.nestedIntsColumn, other.nestedIntsColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetNestedStringColumn()).compareTo(other.isSetNestedStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetNestedStringColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.nestedStringColumn, other.nestedStringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Nested("); + boolean first = true; + + sb.append("nestedIntsColumn:"); + if (this.nestedIntsColumn == null) { + sb.append("null"); + } else { + sb.append(this.nestedIntsColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("nestedStringColumn:"); + if (this.nestedStringColumn == null) { + sb.append("null"); + } else { + sb.append(this.nestedStringColumn); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + if (nestedIntsColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'nestedIntsColumn' was not present! Struct: " + toString()); + } + if (nestedStringColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'nestedStringColumn' was not present! Struct: " + toString()); + } + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class NestedStandardSchemeFactory implements SchemeFactory { + public NestedStandardScheme getScheme() { + return new NestedStandardScheme(); + } + } + + private static class NestedStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, Nested struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // NESTED_INTS_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); + struct.nestedIntsColumn = new ArrayList(_list0.size); + int _elem1; + for (int _i2 = 0; _i2 < _list0.size; ++_i2) + { + _elem1 = iprot.readI32(); + struct.nestedIntsColumn.add(_elem1); + } + iprot.readListEnd(); + } + struct.setNestedIntsColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // NESTED_STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.nestedStringColumn = iprot.readString(); + struct.setNestedStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + + // check for required fields of primitive type, which can't be checked in the validate method + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, Nested struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + if (struct.nestedIntsColumn != null) { + oprot.writeFieldBegin(NESTED_INTS_COLUMN_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, struct.nestedIntsColumn.size())); + for (int _iter3 : struct.nestedIntsColumn) + { + oprot.writeI32(_iter3); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.nestedStringColumn != null) { + oprot.writeFieldBegin(NESTED_STRING_COLUMN_FIELD_DESC); + oprot.writeString(struct.nestedStringColumn); + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class NestedTupleSchemeFactory implements SchemeFactory { + public NestedTupleScheme getScheme() { + return new NestedTupleScheme(); + } + } + + private static class NestedTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, Nested struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + { + oprot.writeI32(struct.nestedIntsColumn.size()); + for (int _iter4 : struct.nestedIntsColumn) + { + oprot.writeI32(_iter4); + } + } + oprot.writeString(struct.nestedStringColumn); + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, Nested struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + { + org.apache.thrift.protocol.TList _list5 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.nestedIntsColumn = new ArrayList(_list5.size); + int _elem6; + for (int _i7 = 0; _i7 < _list5.size; ++_i7) + { + _elem6 = iprot.readI32(); + struct.nestedIntsColumn.add(_elem6); + } + } + struct.setNestedIntsColumnIsSet(true); + struct.nestedStringColumn = iprot.readString(); + struct.setNestedStringColumnIsSet(true); + } + } + +} + diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java new file mode 100644 index 0000000000000..326ae9dbaa0d1 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java @@ -0,0 +1,2808 @@ +/** + * Autogenerated by Thrift Compiler (0.9.2) + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +package org.apache.spark.sql.parquet.test.thrift; + +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.protocol.TProtocolException; +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.TException; +import org.apache.thrift.async.AsyncMethodCallback; +import org.apache.thrift.server.AbstractNonblockingServer.*; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.Set; +import java.util.HashSet; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.annotation.Generated; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings({"cast", "rawtypes", "serial", "unchecked"}) +/** + * This is a test struct for testing parquet-thrift compatibility. + */ +@Generated(value = "Autogenerated by Thrift Compiler (0.9.2)", date = "2015-7-7") +public class ParquetThriftCompat implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("ParquetThriftCompat"); + + private static final org.apache.thrift.protocol.TField BOOL_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("boolColumn", org.apache.thrift.protocol.TType.BOOL, (short)1); + private static final org.apache.thrift.protocol.TField BYTE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("byteColumn", org.apache.thrift.protocol.TType.BYTE, (short)2); + private static final org.apache.thrift.protocol.TField SHORT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("shortColumn", org.apache.thrift.protocol.TType.I16, (short)3); + private static final org.apache.thrift.protocol.TField INT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("intColumn", org.apache.thrift.protocol.TType.I32, (short)4); + private static final org.apache.thrift.protocol.TField LONG_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("longColumn", org.apache.thrift.protocol.TType.I64, (short)5); + private static final org.apache.thrift.protocol.TField DOUBLE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("doubleColumn", org.apache.thrift.protocol.TType.DOUBLE, (short)6); + private static final org.apache.thrift.protocol.TField BINARY_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("binaryColumn", org.apache.thrift.protocol.TType.STRING, (short)7); + private static final org.apache.thrift.protocol.TField STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("stringColumn", org.apache.thrift.protocol.TType.STRING, (short)8); + private static final org.apache.thrift.protocol.TField ENUM_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("enumColumn", org.apache.thrift.protocol.TType.I32, (short)9); + private static final org.apache.thrift.protocol.TField MAYBE_BOOL_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeBoolColumn", org.apache.thrift.protocol.TType.BOOL, (short)10); + private static final org.apache.thrift.protocol.TField MAYBE_BYTE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeByteColumn", org.apache.thrift.protocol.TType.BYTE, (short)11); + private static final org.apache.thrift.protocol.TField MAYBE_SHORT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeShortColumn", org.apache.thrift.protocol.TType.I16, (short)12); + private static final org.apache.thrift.protocol.TField MAYBE_INT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeIntColumn", org.apache.thrift.protocol.TType.I32, (short)13); + private static final org.apache.thrift.protocol.TField MAYBE_LONG_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeLongColumn", org.apache.thrift.protocol.TType.I64, (short)14); + private static final org.apache.thrift.protocol.TField MAYBE_DOUBLE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeDoubleColumn", org.apache.thrift.protocol.TType.DOUBLE, (short)15); + private static final org.apache.thrift.protocol.TField MAYBE_BINARY_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeBinaryColumn", org.apache.thrift.protocol.TType.STRING, (short)16); + private static final org.apache.thrift.protocol.TField MAYBE_STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeStringColumn", org.apache.thrift.protocol.TType.STRING, (short)17); + private static final org.apache.thrift.protocol.TField MAYBE_ENUM_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeEnumColumn", org.apache.thrift.protocol.TType.I32, (short)18); + private static final org.apache.thrift.protocol.TField STRINGS_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("stringsColumn", org.apache.thrift.protocol.TType.LIST, (short)19); + private static final org.apache.thrift.protocol.TField INT_SET_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("intSetColumn", org.apache.thrift.protocol.TType.SET, (short)20); + private static final org.apache.thrift.protocol.TField INT_TO_STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("intToStringColumn", org.apache.thrift.protocol.TType.MAP, (short)21); + private static final org.apache.thrift.protocol.TField COMPLEX_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("complexColumn", org.apache.thrift.protocol.TType.MAP, (short)22); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new ParquetThriftCompatStandardSchemeFactory()); + schemes.put(TupleScheme.class, new ParquetThriftCompatTupleSchemeFactory()); + } + + public boolean boolColumn; // required + public byte byteColumn; // required + public short shortColumn; // required + public int intColumn; // required + public long longColumn; // required + public double doubleColumn; // required + public ByteBuffer binaryColumn; // required + public String stringColumn; // required + /** + * + * @see Suit + */ + public Suit enumColumn; // required + public boolean maybeBoolColumn; // optional + public byte maybeByteColumn; // optional + public short maybeShortColumn; // optional + public int maybeIntColumn; // optional + public long maybeLongColumn; // optional + public double maybeDoubleColumn; // optional + public ByteBuffer maybeBinaryColumn; // optional + public String maybeStringColumn; // optional + /** + * + * @see Suit + */ + public Suit maybeEnumColumn; // optional + public List stringsColumn; // required + public Set intSetColumn; // required + public Map intToStringColumn; // required + public Map> complexColumn; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + BOOL_COLUMN((short)1, "boolColumn"), + BYTE_COLUMN((short)2, "byteColumn"), + SHORT_COLUMN((short)3, "shortColumn"), + INT_COLUMN((short)4, "intColumn"), + LONG_COLUMN((short)5, "longColumn"), + DOUBLE_COLUMN((short)6, "doubleColumn"), + BINARY_COLUMN((short)7, "binaryColumn"), + STRING_COLUMN((short)8, "stringColumn"), + /** + * + * @see Suit + */ + ENUM_COLUMN((short)9, "enumColumn"), + MAYBE_BOOL_COLUMN((short)10, "maybeBoolColumn"), + MAYBE_BYTE_COLUMN((short)11, "maybeByteColumn"), + MAYBE_SHORT_COLUMN((short)12, "maybeShortColumn"), + MAYBE_INT_COLUMN((short)13, "maybeIntColumn"), + MAYBE_LONG_COLUMN((short)14, "maybeLongColumn"), + MAYBE_DOUBLE_COLUMN((short)15, "maybeDoubleColumn"), + MAYBE_BINARY_COLUMN((short)16, "maybeBinaryColumn"), + MAYBE_STRING_COLUMN((short)17, "maybeStringColumn"), + /** + * + * @see Suit + */ + MAYBE_ENUM_COLUMN((short)18, "maybeEnumColumn"), + STRINGS_COLUMN((short)19, "stringsColumn"), + INT_SET_COLUMN((short)20, "intSetColumn"), + INT_TO_STRING_COLUMN((short)21, "intToStringColumn"), + COMPLEX_COLUMN((short)22, "complexColumn"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // BOOL_COLUMN + return BOOL_COLUMN; + case 2: // BYTE_COLUMN + return BYTE_COLUMN; + case 3: // SHORT_COLUMN + return SHORT_COLUMN; + case 4: // INT_COLUMN + return INT_COLUMN; + case 5: // LONG_COLUMN + return LONG_COLUMN; + case 6: // DOUBLE_COLUMN + return DOUBLE_COLUMN; + case 7: // BINARY_COLUMN + return BINARY_COLUMN; + case 8: // STRING_COLUMN + return STRING_COLUMN; + case 9: // ENUM_COLUMN + return ENUM_COLUMN; + case 10: // MAYBE_BOOL_COLUMN + return MAYBE_BOOL_COLUMN; + case 11: // MAYBE_BYTE_COLUMN + return MAYBE_BYTE_COLUMN; + case 12: // MAYBE_SHORT_COLUMN + return MAYBE_SHORT_COLUMN; + case 13: // MAYBE_INT_COLUMN + return MAYBE_INT_COLUMN; + case 14: // MAYBE_LONG_COLUMN + return MAYBE_LONG_COLUMN; + case 15: // MAYBE_DOUBLE_COLUMN + return MAYBE_DOUBLE_COLUMN; + case 16: // MAYBE_BINARY_COLUMN + return MAYBE_BINARY_COLUMN; + case 17: // MAYBE_STRING_COLUMN + return MAYBE_STRING_COLUMN; + case 18: // MAYBE_ENUM_COLUMN + return MAYBE_ENUM_COLUMN; + case 19: // STRINGS_COLUMN + return STRINGS_COLUMN; + case 20: // INT_SET_COLUMN + return INT_SET_COLUMN; + case 21: // INT_TO_STRING_COLUMN + return INT_TO_STRING_COLUMN; + case 22: // COMPLEX_COLUMN + return COMPLEX_COLUMN; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. + */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + private static final int __BOOLCOLUMN_ISSET_ID = 0; + private static final int __BYTECOLUMN_ISSET_ID = 1; + private static final int __SHORTCOLUMN_ISSET_ID = 2; + private static final int __INTCOLUMN_ISSET_ID = 3; + private static final int __LONGCOLUMN_ISSET_ID = 4; + private static final int __DOUBLECOLUMN_ISSET_ID = 5; + private static final int __MAYBEBOOLCOLUMN_ISSET_ID = 6; + private static final int __MAYBEBYTECOLUMN_ISSET_ID = 7; + private static final int __MAYBESHORTCOLUMN_ISSET_ID = 8; + private static final int __MAYBEINTCOLUMN_ISSET_ID = 9; + private static final int __MAYBELONGCOLUMN_ISSET_ID = 10; + private static final int __MAYBEDOUBLECOLUMN_ISSET_ID = 11; + private short __isset_bitfield = 0; + private static final _Fields optionals[] = {_Fields.MAYBE_BOOL_COLUMN,_Fields.MAYBE_BYTE_COLUMN,_Fields.MAYBE_SHORT_COLUMN,_Fields.MAYBE_INT_COLUMN,_Fields.MAYBE_LONG_COLUMN,_Fields.MAYBE_DOUBLE_COLUMN,_Fields.MAYBE_BINARY_COLUMN,_Fields.MAYBE_STRING_COLUMN,_Fields.MAYBE_ENUM_COLUMN}; + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.BOOL_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("boolColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BOOL))); + tmpMap.put(_Fields.BYTE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("byteColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BYTE))); + tmpMap.put(_Fields.SHORT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("shortColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I16))); + tmpMap.put(_Fields.INT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("intColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.LONG_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("longColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I64))); + tmpMap.put(_Fields.DOUBLE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("doubleColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.DOUBLE))); + tmpMap.put(_Fields.BINARY_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("binaryColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING , true))); + tmpMap.put(_Fields.STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("stringColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.ENUM_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("enumColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.EnumMetaData(org.apache.thrift.protocol.TType.ENUM, Suit.class))); + tmpMap.put(_Fields.MAYBE_BOOL_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeBoolColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BOOL))); + tmpMap.put(_Fields.MAYBE_BYTE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeByteColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BYTE))); + tmpMap.put(_Fields.MAYBE_SHORT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeShortColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I16))); + tmpMap.put(_Fields.MAYBE_INT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeIntColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.MAYBE_LONG_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeLongColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I64))); + tmpMap.put(_Fields.MAYBE_DOUBLE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeDoubleColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.DOUBLE))); + tmpMap.put(_Fields.MAYBE_BINARY_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeBinaryColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING , true))); + tmpMap.put(_Fields.MAYBE_STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeStringColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.MAYBE_ENUM_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeEnumColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.EnumMetaData(org.apache.thrift.protocol.TType.ENUM, Suit.class))); + tmpMap.put(_Fields.STRINGS_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("stringsColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.INT_SET_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("intSetColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.SetMetaData(org.apache.thrift.protocol.TType.SET, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.INT_TO_STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("intToStringColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32), + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.COMPLEX_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("complexColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32), + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.StructMetaData(org.apache.thrift.protocol.TType.STRUCT, Nested.class))))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(ParquetThriftCompat.class, metaDataMap); + } + + public ParquetThriftCompat() { + } + + public ParquetThriftCompat( + boolean boolColumn, + byte byteColumn, + short shortColumn, + int intColumn, + long longColumn, + double doubleColumn, + ByteBuffer binaryColumn, + String stringColumn, + Suit enumColumn, + List stringsColumn, + Set intSetColumn, + Map intToStringColumn, + Map> complexColumn) + { + this(); + this.boolColumn = boolColumn; + setBoolColumnIsSet(true); + this.byteColumn = byteColumn; + setByteColumnIsSet(true); + this.shortColumn = shortColumn; + setShortColumnIsSet(true); + this.intColumn = intColumn; + setIntColumnIsSet(true); + this.longColumn = longColumn; + setLongColumnIsSet(true); + this.doubleColumn = doubleColumn; + setDoubleColumnIsSet(true); + this.binaryColumn = org.apache.thrift.TBaseHelper.copyBinary(binaryColumn); + this.stringColumn = stringColumn; + this.enumColumn = enumColumn; + this.stringsColumn = stringsColumn; + this.intSetColumn = intSetColumn; + this.intToStringColumn = intToStringColumn; + this.complexColumn = complexColumn; + } + + /** + * Performs a deep copy on other. + */ + public ParquetThriftCompat(ParquetThriftCompat other) { + __isset_bitfield = other.__isset_bitfield; + this.boolColumn = other.boolColumn; + this.byteColumn = other.byteColumn; + this.shortColumn = other.shortColumn; + this.intColumn = other.intColumn; + this.longColumn = other.longColumn; + this.doubleColumn = other.doubleColumn; + if (other.isSetBinaryColumn()) { + this.binaryColumn = org.apache.thrift.TBaseHelper.copyBinary(other.binaryColumn); + } + if (other.isSetStringColumn()) { + this.stringColumn = other.stringColumn; + } + if (other.isSetEnumColumn()) { + this.enumColumn = other.enumColumn; + } + this.maybeBoolColumn = other.maybeBoolColumn; + this.maybeByteColumn = other.maybeByteColumn; + this.maybeShortColumn = other.maybeShortColumn; + this.maybeIntColumn = other.maybeIntColumn; + this.maybeLongColumn = other.maybeLongColumn; + this.maybeDoubleColumn = other.maybeDoubleColumn; + if (other.isSetMaybeBinaryColumn()) { + this.maybeBinaryColumn = org.apache.thrift.TBaseHelper.copyBinary(other.maybeBinaryColumn); + } + if (other.isSetMaybeStringColumn()) { + this.maybeStringColumn = other.maybeStringColumn; + } + if (other.isSetMaybeEnumColumn()) { + this.maybeEnumColumn = other.maybeEnumColumn; + } + if (other.isSetStringsColumn()) { + List __this__stringsColumn = new ArrayList(other.stringsColumn); + this.stringsColumn = __this__stringsColumn; + } + if (other.isSetIntSetColumn()) { + Set __this__intSetColumn = new HashSet(other.intSetColumn); + this.intSetColumn = __this__intSetColumn; + } + if (other.isSetIntToStringColumn()) { + Map __this__intToStringColumn = new HashMap(other.intToStringColumn); + this.intToStringColumn = __this__intToStringColumn; + } + if (other.isSetComplexColumn()) { + Map> __this__complexColumn = new HashMap>(other.complexColumn.size()); + for (Map.Entry> other_element : other.complexColumn.entrySet()) { + + Integer other_element_key = other_element.getKey(); + List other_element_value = other_element.getValue(); + + Integer __this__complexColumn_copy_key = other_element_key; + + List __this__complexColumn_copy_value = new ArrayList(other_element_value.size()); + for (Nested other_element_value_element : other_element_value) { + __this__complexColumn_copy_value.add(new Nested(other_element_value_element)); + } + + __this__complexColumn.put(__this__complexColumn_copy_key, __this__complexColumn_copy_value); + } + this.complexColumn = __this__complexColumn; + } + } + + public ParquetThriftCompat deepCopy() { + return new ParquetThriftCompat(this); + } + + @Override + public void clear() { + setBoolColumnIsSet(false); + this.boolColumn = false; + setByteColumnIsSet(false); + this.byteColumn = 0; + setShortColumnIsSet(false); + this.shortColumn = 0; + setIntColumnIsSet(false); + this.intColumn = 0; + setLongColumnIsSet(false); + this.longColumn = 0; + setDoubleColumnIsSet(false); + this.doubleColumn = 0.0; + this.binaryColumn = null; + this.stringColumn = null; + this.enumColumn = null; + setMaybeBoolColumnIsSet(false); + this.maybeBoolColumn = false; + setMaybeByteColumnIsSet(false); + this.maybeByteColumn = 0; + setMaybeShortColumnIsSet(false); + this.maybeShortColumn = 0; + setMaybeIntColumnIsSet(false); + this.maybeIntColumn = 0; + setMaybeLongColumnIsSet(false); + this.maybeLongColumn = 0; + setMaybeDoubleColumnIsSet(false); + this.maybeDoubleColumn = 0.0; + this.maybeBinaryColumn = null; + this.maybeStringColumn = null; + this.maybeEnumColumn = null; + this.stringsColumn = null; + this.intSetColumn = null; + this.intToStringColumn = null; + this.complexColumn = null; + } + + public boolean isBoolColumn() { + return this.boolColumn; + } + + public ParquetThriftCompat setBoolColumn(boolean boolColumn) { + this.boolColumn = boolColumn; + setBoolColumnIsSet(true); + return this; + } + + public void unsetBoolColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __BOOLCOLUMN_ISSET_ID); + } + + /** Returns true if field boolColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetBoolColumn() { + return EncodingUtils.testBit(__isset_bitfield, __BOOLCOLUMN_ISSET_ID); + } + + public void setBoolColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __BOOLCOLUMN_ISSET_ID, value); + } + + public byte getByteColumn() { + return this.byteColumn; + } + + public ParquetThriftCompat setByteColumn(byte byteColumn) { + this.byteColumn = byteColumn; + setByteColumnIsSet(true); + return this; + } + + public void unsetByteColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __BYTECOLUMN_ISSET_ID); + } + + /** Returns true if field byteColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetByteColumn() { + return EncodingUtils.testBit(__isset_bitfield, __BYTECOLUMN_ISSET_ID); + } + + public void setByteColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __BYTECOLUMN_ISSET_ID, value); + } + + public short getShortColumn() { + return this.shortColumn; + } + + public ParquetThriftCompat setShortColumn(short shortColumn) { + this.shortColumn = shortColumn; + setShortColumnIsSet(true); + return this; + } + + public void unsetShortColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __SHORTCOLUMN_ISSET_ID); + } + + /** Returns true if field shortColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetShortColumn() { + return EncodingUtils.testBit(__isset_bitfield, __SHORTCOLUMN_ISSET_ID); + } + + public void setShortColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __SHORTCOLUMN_ISSET_ID, value); + } + + public int getIntColumn() { + return this.intColumn; + } + + public ParquetThriftCompat setIntColumn(int intColumn) { + this.intColumn = intColumn; + setIntColumnIsSet(true); + return this; + } + + public void unsetIntColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __INTCOLUMN_ISSET_ID); + } + + /** Returns true if field intColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetIntColumn() { + return EncodingUtils.testBit(__isset_bitfield, __INTCOLUMN_ISSET_ID); + } + + public void setIntColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __INTCOLUMN_ISSET_ID, value); + } + + public long getLongColumn() { + return this.longColumn; + } + + public ParquetThriftCompat setLongColumn(long longColumn) { + this.longColumn = longColumn; + setLongColumnIsSet(true); + return this; + } + + public void unsetLongColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __LONGCOLUMN_ISSET_ID); + } + + /** Returns true if field longColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetLongColumn() { + return EncodingUtils.testBit(__isset_bitfield, __LONGCOLUMN_ISSET_ID); + } + + public void setLongColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __LONGCOLUMN_ISSET_ID, value); + } + + public double getDoubleColumn() { + return this.doubleColumn; + } + + public ParquetThriftCompat setDoubleColumn(double doubleColumn) { + this.doubleColumn = doubleColumn; + setDoubleColumnIsSet(true); + return this; + } + + public void unsetDoubleColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __DOUBLECOLUMN_ISSET_ID); + } + + /** Returns true if field doubleColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetDoubleColumn() { + return EncodingUtils.testBit(__isset_bitfield, __DOUBLECOLUMN_ISSET_ID); + } + + public void setDoubleColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __DOUBLECOLUMN_ISSET_ID, value); + } + + public byte[] getBinaryColumn() { + setBinaryColumn(org.apache.thrift.TBaseHelper.rightSize(binaryColumn)); + return binaryColumn == null ? null : binaryColumn.array(); + } + + public ByteBuffer bufferForBinaryColumn() { + return org.apache.thrift.TBaseHelper.copyBinary(binaryColumn); + } + + public ParquetThriftCompat setBinaryColumn(byte[] binaryColumn) { + this.binaryColumn = binaryColumn == null ? (ByteBuffer)null : ByteBuffer.wrap(Arrays.copyOf(binaryColumn, binaryColumn.length)); + return this; + } + + public ParquetThriftCompat setBinaryColumn(ByteBuffer binaryColumn) { + this.binaryColumn = org.apache.thrift.TBaseHelper.copyBinary(binaryColumn); + return this; + } + + public void unsetBinaryColumn() { + this.binaryColumn = null; + } + + /** Returns true if field binaryColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetBinaryColumn() { + return this.binaryColumn != null; + } + + public void setBinaryColumnIsSet(boolean value) { + if (!value) { + this.binaryColumn = null; + } + } + + public String getStringColumn() { + return this.stringColumn; + } + + public ParquetThriftCompat setStringColumn(String stringColumn) { + this.stringColumn = stringColumn; + return this; + } + + public void unsetStringColumn() { + this.stringColumn = null; + } + + /** Returns true if field stringColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetStringColumn() { + return this.stringColumn != null; + } + + public void setStringColumnIsSet(boolean value) { + if (!value) { + this.stringColumn = null; + } + } + + /** + * + * @see Suit + */ + public Suit getEnumColumn() { + return this.enumColumn; + } + + /** + * + * @see Suit + */ + public ParquetThriftCompat setEnumColumn(Suit enumColumn) { + this.enumColumn = enumColumn; + return this; + } + + public void unsetEnumColumn() { + this.enumColumn = null; + } + + /** Returns true if field enumColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetEnumColumn() { + return this.enumColumn != null; + } + + public void setEnumColumnIsSet(boolean value) { + if (!value) { + this.enumColumn = null; + } + } + + public boolean isMaybeBoolColumn() { + return this.maybeBoolColumn; + } + + public ParquetThriftCompat setMaybeBoolColumn(boolean maybeBoolColumn) { + this.maybeBoolColumn = maybeBoolColumn; + setMaybeBoolColumnIsSet(true); + return this; + } + + public void unsetMaybeBoolColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEBOOLCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeBoolColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeBoolColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEBOOLCOLUMN_ISSET_ID); + } + + public void setMaybeBoolColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBEBOOLCOLUMN_ISSET_ID, value); + } + + public byte getMaybeByteColumn() { + return this.maybeByteColumn; + } + + public ParquetThriftCompat setMaybeByteColumn(byte maybeByteColumn) { + this.maybeByteColumn = maybeByteColumn; + setMaybeByteColumnIsSet(true); + return this; + } + + public void unsetMaybeByteColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEBYTECOLUMN_ISSET_ID); + } + + /** Returns true if field maybeByteColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeByteColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEBYTECOLUMN_ISSET_ID); + } + + public void setMaybeByteColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBEBYTECOLUMN_ISSET_ID, value); + } + + public short getMaybeShortColumn() { + return this.maybeShortColumn; + } + + public ParquetThriftCompat setMaybeShortColumn(short maybeShortColumn) { + this.maybeShortColumn = maybeShortColumn; + setMaybeShortColumnIsSet(true); + return this; + } + + public void unsetMaybeShortColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBESHORTCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeShortColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeShortColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBESHORTCOLUMN_ISSET_ID); + } + + public void setMaybeShortColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBESHORTCOLUMN_ISSET_ID, value); + } + + public int getMaybeIntColumn() { + return this.maybeIntColumn; + } + + public ParquetThriftCompat setMaybeIntColumn(int maybeIntColumn) { + this.maybeIntColumn = maybeIntColumn; + setMaybeIntColumnIsSet(true); + return this; + } + + public void unsetMaybeIntColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEINTCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeIntColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeIntColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEINTCOLUMN_ISSET_ID); + } + + public void setMaybeIntColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBEINTCOLUMN_ISSET_ID, value); + } + + public long getMaybeLongColumn() { + return this.maybeLongColumn; + } + + public ParquetThriftCompat setMaybeLongColumn(long maybeLongColumn) { + this.maybeLongColumn = maybeLongColumn; + setMaybeLongColumnIsSet(true); + return this; + } + + public void unsetMaybeLongColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBELONGCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeLongColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeLongColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBELONGCOLUMN_ISSET_ID); + } + + public void setMaybeLongColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBELONGCOLUMN_ISSET_ID, value); + } + + public double getMaybeDoubleColumn() { + return this.maybeDoubleColumn; + } + + public ParquetThriftCompat setMaybeDoubleColumn(double maybeDoubleColumn) { + this.maybeDoubleColumn = maybeDoubleColumn; + setMaybeDoubleColumnIsSet(true); + return this; + } + + public void unsetMaybeDoubleColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEDOUBLECOLUMN_ISSET_ID); + } + + /** Returns true if field maybeDoubleColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeDoubleColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEDOUBLECOLUMN_ISSET_ID); + } + + public void setMaybeDoubleColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBEDOUBLECOLUMN_ISSET_ID, value); + } + + public byte[] getMaybeBinaryColumn() { + setMaybeBinaryColumn(org.apache.thrift.TBaseHelper.rightSize(maybeBinaryColumn)); + return maybeBinaryColumn == null ? null : maybeBinaryColumn.array(); + } + + public ByteBuffer bufferForMaybeBinaryColumn() { + return org.apache.thrift.TBaseHelper.copyBinary(maybeBinaryColumn); + } + + public ParquetThriftCompat setMaybeBinaryColumn(byte[] maybeBinaryColumn) { + this.maybeBinaryColumn = maybeBinaryColumn == null ? (ByteBuffer)null : ByteBuffer.wrap(Arrays.copyOf(maybeBinaryColumn, maybeBinaryColumn.length)); + return this; + } + + public ParquetThriftCompat setMaybeBinaryColumn(ByteBuffer maybeBinaryColumn) { + this.maybeBinaryColumn = org.apache.thrift.TBaseHelper.copyBinary(maybeBinaryColumn); + return this; + } + + public void unsetMaybeBinaryColumn() { + this.maybeBinaryColumn = null; + } + + /** Returns true if field maybeBinaryColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeBinaryColumn() { + return this.maybeBinaryColumn != null; + } + + public void setMaybeBinaryColumnIsSet(boolean value) { + if (!value) { + this.maybeBinaryColumn = null; + } + } + + public String getMaybeStringColumn() { + return this.maybeStringColumn; + } + + public ParquetThriftCompat setMaybeStringColumn(String maybeStringColumn) { + this.maybeStringColumn = maybeStringColumn; + return this; + } + + public void unsetMaybeStringColumn() { + this.maybeStringColumn = null; + } + + /** Returns true if field maybeStringColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeStringColumn() { + return this.maybeStringColumn != null; + } + + public void setMaybeStringColumnIsSet(boolean value) { + if (!value) { + this.maybeStringColumn = null; + } + } + + /** + * + * @see Suit + */ + public Suit getMaybeEnumColumn() { + return this.maybeEnumColumn; + } + + /** + * + * @see Suit + */ + public ParquetThriftCompat setMaybeEnumColumn(Suit maybeEnumColumn) { + this.maybeEnumColumn = maybeEnumColumn; + return this; + } + + public void unsetMaybeEnumColumn() { + this.maybeEnumColumn = null; + } + + /** Returns true if field maybeEnumColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeEnumColumn() { + return this.maybeEnumColumn != null; + } + + public void setMaybeEnumColumnIsSet(boolean value) { + if (!value) { + this.maybeEnumColumn = null; + } + } + + public int getStringsColumnSize() { + return (this.stringsColumn == null) ? 0 : this.stringsColumn.size(); + } + + public java.util.Iterator getStringsColumnIterator() { + return (this.stringsColumn == null) ? null : this.stringsColumn.iterator(); + } + + public void addToStringsColumn(String elem) { + if (this.stringsColumn == null) { + this.stringsColumn = new ArrayList(); + } + this.stringsColumn.add(elem); + } + + public List getStringsColumn() { + return this.stringsColumn; + } + + public ParquetThriftCompat setStringsColumn(List stringsColumn) { + this.stringsColumn = stringsColumn; + return this; + } + + public void unsetStringsColumn() { + this.stringsColumn = null; + } + + /** Returns true if field stringsColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetStringsColumn() { + return this.stringsColumn != null; + } + + public void setStringsColumnIsSet(boolean value) { + if (!value) { + this.stringsColumn = null; + } + } + + public int getIntSetColumnSize() { + return (this.intSetColumn == null) ? 0 : this.intSetColumn.size(); + } + + public java.util.Iterator getIntSetColumnIterator() { + return (this.intSetColumn == null) ? null : this.intSetColumn.iterator(); + } + + public void addToIntSetColumn(int elem) { + if (this.intSetColumn == null) { + this.intSetColumn = new HashSet(); + } + this.intSetColumn.add(elem); + } + + public Set getIntSetColumn() { + return this.intSetColumn; + } + + public ParquetThriftCompat setIntSetColumn(Set intSetColumn) { + this.intSetColumn = intSetColumn; + return this; + } + + public void unsetIntSetColumn() { + this.intSetColumn = null; + } + + /** Returns true if field intSetColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetIntSetColumn() { + return this.intSetColumn != null; + } + + public void setIntSetColumnIsSet(boolean value) { + if (!value) { + this.intSetColumn = null; + } + } + + public int getIntToStringColumnSize() { + return (this.intToStringColumn == null) ? 0 : this.intToStringColumn.size(); + } + + public void putToIntToStringColumn(int key, String val) { + if (this.intToStringColumn == null) { + this.intToStringColumn = new HashMap(); + } + this.intToStringColumn.put(key, val); + } + + public Map getIntToStringColumn() { + return this.intToStringColumn; + } + + public ParquetThriftCompat setIntToStringColumn(Map intToStringColumn) { + this.intToStringColumn = intToStringColumn; + return this; + } + + public void unsetIntToStringColumn() { + this.intToStringColumn = null; + } + + /** Returns true if field intToStringColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetIntToStringColumn() { + return this.intToStringColumn != null; + } + + public void setIntToStringColumnIsSet(boolean value) { + if (!value) { + this.intToStringColumn = null; + } + } + + public int getComplexColumnSize() { + return (this.complexColumn == null) ? 0 : this.complexColumn.size(); + } + + public void putToComplexColumn(int key, List val) { + if (this.complexColumn == null) { + this.complexColumn = new HashMap>(); + } + this.complexColumn.put(key, val); + } + + public Map> getComplexColumn() { + return this.complexColumn; + } + + public ParquetThriftCompat setComplexColumn(Map> complexColumn) { + this.complexColumn = complexColumn; + return this; + } + + public void unsetComplexColumn() { + this.complexColumn = null; + } + + /** Returns true if field complexColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetComplexColumn() { + return this.complexColumn != null; + } + + public void setComplexColumnIsSet(boolean value) { + if (!value) { + this.complexColumn = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case BOOL_COLUMN: + if (value == null) { + unsetBoolColumn(); + } else { + setBoolColumn((Boolean)value); + } + break; + + case BYTE_COLUMN: + if (value == null) { + unsetByteColumn(); + } else { + setByteColumn((Byte)value); + } + break; + + case SHORT_COLUMN: + if (value == null) { + unsetShortColumn(); + } else { + setShortColumn((Short)value); + } + break; + + case INT_COLUMN: + if (value == null) { + unsetIntColumn(); + } else { + setIntColumn((Integer)value); + } + break; + + case LONG_COLUMN: + if (value == null) { + unsetLongColumn(); + } else { + setLongColumn((Long)value); + } + break; + + case DOUBLE_COLUMN: + if (value == null) { + unsetDoubleColumn(); + } else { + setDoubleColumn((Double)value); + } + break; + + case BINARY_COLUMN: + if (value == null) { + unsetBinaryColumn(); + } else { + setBinaryColumn((ByteBuffer)value); + } + break; + + case STRING_COLUMN: + if (value == null) { + unsetStringColumn(); + } else { + setStringColumn((String)value); + } + break; + + case ENUM_COLUMN: + if (value == null) { + unsetEnumColumn(); + } else { + setEnumColumn((Suit)value); + } + break; + + case MAYBE_BOOL_COLUMN: + if (value == null) { + unsetMaybeBoolColumn(); + } else { + setMaybeBoolColumn((Boolean)value); + } + break; + + case MAYBE_BYTE_COLUMN: + if (value == null) { + unsetMaybeByteColumn(); + } else { + setMaybeByteColumn((Byte)value); + } + break; + + case MAYBE_SHORT_COLUMN: + if (value == null) { + unsetMaybeShortColumn(); + } else { + setMaybeShortColumn((Short)value); + } + break; + + case MAYBE_INT_COLUMN: + if (value == null) { + unsetMaybeIntColumn(); + } else { + setMaybeIntColumn((Integer)value); + } + break; + + case MAYBE_LONG_COLUMN: + if (value == null) { + unsetMaybeLongColumn(); + } else { + setMaybeLongColumn((Long)value); + } + break; + + case MAYBE_DOUBLE_COLUMN: + if (value == null) { + unsetMaybeDoubleColumn(); + } else { + setMaybeDoubleColumn((Double)value); + } + break; + + case MAYBE_BINARY_COLUMN: + if (value == null) { + unsetMaybeBinaryColumn(); + } else { + setMaybeBinaryColumn((ByteBuffer)value); + } + break; + + case MAYBE_STRING_COLUMN: + if (value == null) { + unsetMaybeStringColumn(); + } else { + setMaybeStringColumn((String)value); + } + break; + + case MAYBE_ENUM_COLUMN: + if (value == null) { + unsetMaybeEnumColumn(); + } else { + setMaybeEnumColumn((Suit)value); + } + break; + + case STRINGS_COLUMN: + if (value == null) { + unsetStringsColumn(); + } else { + setStringsColumn((List)value); + } + break; + + case INT_SET_COLUMN: + if (value == null) { + unsetIntSetColumn(); + } else { + setIntSetColumn((Set)value); + } + break; + + case INT_TO_STRING_COLUMN: + if (value == null) { + unsetIntToStringColumn(); + } else { + setIntToStringColumn((Map)value); + } + break; + + case COMPLEX_COLUMN: + if (value == null) { + unsetComplexColumn(); + } else { + setComplexColumn((Map>)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case BOOL_COLUMN: + return Boolean.valueOf(isBoolColumn()); + + case BYTE_COLUMN: + return Byte.valueOf(getByteColumn()); + + case SHORT_COLUMN: + return Short.valueOf(getShortColumn()); + + case INT_COLUMN: + return Integer.valueOf(getIntColumn()); + + case LONG_COLUMN: + return Long.valueOf(getLongColumn()); + + case DOUBLE_COLUMN: + return Double.valueOf(getDoubleColumn()); + + case BINARY_COLUMN: + return getBinaryColumn(); + + case STRING_COLUMN: + return getStringColumn(); + + case ENUM_COLUMN: + return getEnumColumn(); + + case MAYBE_BOOL_COLUMN: + return Boolean.valueOf(isMaybeBoolColumn()); + + case MAYBE_BYTE_COLUMN: + return Byte.valueOf(getMaybeByteColumn()); + + case MAYBE_SHORT_COLUMN: + return Short.valueOf(getMaybeShortColumn()); + + case MAYBE_INT_COLUMN: + return Integer.valueOf(getMaybeIntColumn()); + + case MAYBE_LONG_COLUMN: + return Long.valueOf(getMaybeLongColumn()); + + case MAYBE_DOUBLE_COLUMN: + return Double.valueOf(getMaybeDoubleColumn()); + + case MAYBE_BINARY_COLUMN: + return getMaybeBinaryColumn(); + + case MAYBE_STRING_COLUMN: + return getMaybeStringColumn(); + + case MAYBE_ENUM_COLUMN: + return getMaybeEnumColumn(); + + case STRINGS_COLUMN: + return getStringsColumn(); + + case INT_SET_COLUMN: + return getIntSetColumn(); + + case INT_TO_STRING_COLUMN: + return getIntToStringColumn(); + + case COMPLEX_COLUMN: + return getComplexColumn(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case BOOL_COLUMN: + return isSetBoolColumn(); + case BYTE_COLUMN: + return isSetByteColumn(); + case SHORT_COLUMN: + return isSetShortColumn(); + case INT_COLUMN: + return isSetIntColumn(); + case LONG_COLUMN: + return isSetLongColumn(); + case DOUBLE_COLUMN: + return isSetDoubleColumn(); + case BINARY_COLUMN: + return isSetBinaryColumn(); + case STRING_COLUMN: + return isSetStringColumn(); + case ENUM_COLUMN: + return isSetEnumColumn(); + case MAYBE_BOOL_COLUMN: + return isSetMaybeBoolColumn(); + case MAYBE_BYTE_COLUMN: + return isSetMaybeByteColumn(); + case MAYBE_SHORT_COLUMN: + return isSetMaybeShortColumn(); + case MAYBE_INT_COLUMN: + return isSetMaybeIntColumn(); + case MAYBE_LONG_COLUMN: + return isSetMaybeLongColumn(); + case MAYBE_DOUBLE_COLUMN: + return isSetMaybeDoubleColumn(); + case MAYBE_BINARY_COLUMN: + return isSetMaybeBinaryColumn(); + case MAYBE_STRING_COLUMN: + return isSetMaybeStringColumn(); + case MAYBE_ENUM_COLUMN: + return isSetMaybeEnumColumn(); + case STRINGS_COLUMN: + return isSetStringsColumn(); + case INT_SET_COLUMN: + return isSetIntSetColumn(); + case INT_TO_STRING_COLUMN: + return isSetIntToStringColumn(); + case COMPLEX_COLUMN: + return isSetComplexColumn(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof ParquetThriftCompat) + return this.equals((ParquetThriftCompat)that); + return false; + } + + public boolean equals(ParquetThriftCompat that) { + if (that == null) + return false; + + boolean this_present_boolColumn = true; + boolean that_present_boolColumn = true; + if (this_present_boolColumn || that_present_boolColumn) { + if (!(this_present_boolColumn && that_present_boolColumn)) + return false; + if (this.boolColumn != that.boolColumn) + return false; + } + + boolean this_present_byteColumn = true; + boolean that_present_byteColumn = true; + if (this_present_byteColumn || that_present_byteColumn) { + if (!(this_present_byteColumn && that_present_byteColumn)) + return false; + if (this.byteColumn != that.byteColumn) + return false; + } + + boolean this_present_shortColumn = true; + boolean that_present_shortColumn = true; + if (this_present_shortColumn || that_present_shortColumn) { + if (!(this_present_shortColumn && that_present_shortColumn)) + return false; + if (this.shortColumn != that.shortColumn) + return false; + } + + boolean this_present_intColumn = true; + boolean that_present_intColumn = true; + if (this_present_intColumn || that_present_intColumn) { + if (!(this_present_intColumn && that_present_intColumn)) + return false; + if (this.intColumn != that.intColumn) + return false; + } + + boolean this_present_longColumn = true; + boolean that_present_longColumn = true; + if (this_present_longColumn || that_present_longColumn) { + if (!(this_present_longColumn && that_present_longColumn)) + return false; + if (this.longColumn != that.longColumn) + return false; + } + + boolean this_present_doubleColumn = true; + boolean that_present_doubleColumn = true; + if (this_present_doubleColumn || that_present_doubleColumn) { + if (!(this_present_doubleColumn && that_present_doubleColumn)) + return false; + if (this.doubleColumn != that.doubleColumn) + return false; + } + + boolean this_present_binaryColumn = true && this.isSetBinaryColumn(); + boolean that_present_binaryColumn = true && that.isSetBinaryColumn(); + if (this_present_binaryColumn || that_present_binaryColumn) { + if (!(this_present_binaryColumn && that_present_binaryColumn)) + return false; + if (!this.binaryColumn.equals(that.binaryColumn)) + return false; + } + + boolean this_present_stringColumn = true && this.isSetStringColumn(); + boolean that_present_stringColumn = true && that.isSetStringColumn(); + if (this_present_stringColumn || that_present_stringColumn) { + if (!(this_present_stringColumn && that_present_stringColumn)) + return false; + if (!this.stringColumn.equals(that.stringColumn)) + return false; + } + + boolean this_present_enumColumn = true && this.isSetEnumColumn(); + boolean that_present_enumColumn = true && that.isSetEnumColumn(); + if (this_present_enumColumn || that_present_enumColumn) { + if (!(this_present_enumColumn && that_present_enumColumn)) + return false; + if (!this.enumColumn.equals(that.enumColumn)) + return false; + } + + boolean this_present_maybeBoolColumn = true && this.isSetMaybeBoolColumn(); + boolean that_present_maybeBoolColumn = true && that.isSetMaybeBoolColumn(); + if (this_present_maybeBoolColumn || that_present_maybeBoolColumn) { + if (!(this_present_maybeBoolColumn && that_present_maybeBoolColumn)) + return false; + if (this.maybeBoolColumn != that.maybeBoolColumn) + return false; + } + + boolean this_present_maybeByteColumn = true && this.isSetMaybeByteColumn(); + boolean that_present_maybeByteColumn = true && that.isSetMaybeByteColumn(); + if (this_present_maybeByteColumn || that_present_maybeByteColumn) { + if (!(this_present_maybeByteColumn && that_present_maybeByteColumn)) + return false; + if (this.maybeByteColumn != that.maybeByteColumn) + return false; + } + + boolean this_present_maybeShortColumn = true && this.isSetMaybeShortColumn(); + boolean that_present_maybeShortColumn = true && that.isSetMaybeShortColumn(); + if (this_present_maybeShortColumn || that_present_maybeShortColumn) { + if (!(this_present_maybeShortColumn && that_present_maybeShortColumn)) + return false; + if (this.maybeShortColumn != that.maybeShortColumn) + return false; + } + + boolean this_present_maybeIntColumn = true && this.isSetMaybeIntColumn(); + boolean that_present_maybeIntColumn = true && that.isSetMaybeIntColumn(); + if (this_present_maybeIntColumn || that_present_maybeIntColumn) { + if (!(this_present_maybeIntColumn && that_present_maybeIntColumn)) + return false; + if (this.maybeIntColumn != that.maybeIntColumn) + return false; + } + + boolean this_present_maybeLongColumn = true && this.isSetMaybeLongColumn(); + boolean that_present_maybeLongColumn = true && that.isSetMaybeLongColumn(); + if (this_present_maybeLongColumn || that_present_maybeLongColumn) { + if (!(this_present_maybeLongColumn && that_present_maybeLongColumn)) + return false; + if (this.maybeLongColumn != that.maybeLongColumn) + return false; + } + + boolean this_present_maybeDoubleColumn = true && this.isSetMaybeDoubleColumn(); + boolean that_present_maybeDoubleColumn = true && that.isSetMaybeDoubleColumn(); + if (this_present_maybeDoubleColumn || that_present_maybeDoubleColumn) { + if (!(this_present_maybeDoubleColumn && that_present_maybeDoubleColumn)) + return false; + if (this.maybeDoubleColumn != that.maybeDoubleColumn) + return false; + } + + boolean this_present_maybeBinaryColumn = true && this.isSetMaybeBinaryColumn(); + boolean that_present_maybeBinaryColumn = true && that.isSetMaybeBinaryColumn(); + if (this_present_maybeBinaryColumn || that_present_maybeBinaryColumn) { + if (!(this_present_maybeBinaryColumn && that_present_maybeBinaryColumn)) + return false; + if (!this.maybeBinaryColumn.equals(that.maybeBinaryColumn)) + return false; + } + + boolean this_present_maybeStringColumn = true && this.isSetMaybeStringColumn(); + boolean that_present_maybeStringColumn = true && that.isSetMaybeStringColumn(); + if (this_present_maybeStringColumn || that_present_maybeStringColumn) { + if (!(this_present_maybeStringColumn && that_present_maybeStringColumn)) + return false; + if (!this.maybeStringColumn.equals(that.maybeStringColumn)) + return false; + } + + boolean this_present_maybeEnumColumn = true && this.isSetMaybeEnumColumn(); + boolean that_present_maybeEnumColumn = true && that.isSetMaybeEnumColumn(); + if (this_present_maybeEnumColumn || that_present_maybeEnumColumn) { + if (!(this_present_maybeEnumColumn && that_present_maybeEnumColumn)) + return false; + if (!this.maybeEnumColumn.equals(that.maybeEnumColumn)) + return false; + } + + boolean this_present_stringsColumn = true && this.isSetStringsColumn(); + boolean that_present_stringsColumn = true && that.isSetStringsColumn(); + if (this_present_stringsColumn || that_present_stringsColumn) { + if (!(this_present_stringsColumn && that_present_stringsColumn)) + return false; + if (!this.stringsColumn.equals(that.stringsColumn)) + return false; + } + + boolean this_present_intSetColumn = true && this.isSetIntSetColumn(); + boolean that_present_intSetColumn = true && that.isSetIntSetColumn(); + if (this_present_intSetColumn || that_present_intSetColumn) { + if (!(this_present_intSetColumn && that_present_intSetColumn)) + return false; + if (!this.intSetColumn.equals(that.intSetColumn)) + return false; + } + + boolean this_present_intToStringColumn = true && this.isSetIntToStringColumn(); + boolean that_present_intToStringColumn = true && that.isSetIntToStringColumn(); + if (this_present_intToStringColumn || that_present_intToStringColumn) { + if (!(this_present_intToStringColumn && that_present_intToStringColumn)) + return false; + if (!this.intToStringColumn.equals(that.intToStringColumn)) + return false; + } + + boolean this_present_complexColumn = true && this.isSetComplexColumn(); + boolean that_present_complexColumn = true && that.isSetComplexColumn(); + if (this_present_complexColumn || that_present_complexColumn) { + if (!(this_present_complexColumn && that_present_complexColumn)) + return false; + if (!this.complexColumn.equals(that.complexColumn)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + List list = new ArrayList(); + + boolean present_boolColumn = true; + list.add(present_boolColumn); + if (present_boolColumn) + list.add(boolColumn); + + boolean present_byteColumn = true; + list.add(present_byteColumn); + if (present_byteColumn) + list.add(byteColumn); + + boolean present_shortColumn = true; + list.add(present_shortColumn); + if (present_shortColumn) + list.add(shortColumn); + + boolean present_intColumn = true; + list.add(present_intColumn); + if (present_intColumn) + list.add(intColumn); + + boolean present_longColumn = true; + list.add(present_longColumn); + if (present_longColumn) + list.add(longColumn); + + boolean present_doubleColumn = true; + list.add(present_doubleColumn); + if (present_doubleColumn) + list.add(doubleColumn); + + boolean present_binaryColumn = true && (isSetBinaryColumn()); + list.add(present_binaryColumn); + if (present_binaryColumn) + list.add(binaryColumn); + + boolean present_stringColumn = true && (isSetStringColumn()); + list.add(present_stringColumn); + if (present_stringColumn) + list.add(stringColumn); + + boolean present_enumColumn = true && (isSetEnumColumn()); + list.add(present_enumColumn); + if (present_enumColumn) + list.add(enumColumn.getValue()); + + boolean present_maybeBoolColumn = true && (isSetMaybeBoolColumn()); + list.add(present_maybeBoolColumn); + if (present_maybeBoolColumn) + list.add(maybeBoolColumn); + + boolean present_maybeByteColumn = true && (isSetMaybeByteColumn()); + list.add(present_maybeByteColumn); + if (present_maybeByteColumn) + list.add(maybeByteColumn); + + boolean present_maybeShortColumn = true && (isSetMaybeShortColumn()); + list.add(present_maybeShortColumn); + if (present_maybeShortColumn) + list.add(maybeShortColumn); + + boolean present_maybeIntColumn = true && (isSetMaybeIntColumn()); + list.add(present_maybeIntColumn); + if (present_maybeIntColumn) + list.add(maybeIntColumn); + + boolean present_maybeLongColumn = true && (isSetMaybeLongColumn()); + list.add(present_maybeLongColumn); + if (present_maybeLongColumn) + list.add(maybeLongColumn); + + boolean present_maybeDoubleColumn = true && (isSetMaybeDoubleColumn()); + list.add(present_maybeDoubleColumn); + if (present_maybeDoubleColumn) + list.add(maybeDoubleColumn); + + boolean present_maybeBinaryColumn = true && (isSetMaybeBinaryColumn()); + list.add(present_maybeBinaryColumn); + if (present_maybeBinaryColumn) + list.add(maybeBinaryColumn); + + boolean present_maybeStringColumn = true && (isSetMaybeStringColumn()); + list.add(present_maybeStringColumn); + if (present_maybeStringColumn) + list.add(maybeStringColumn); + + boolean present_maybeEnumColumn = true && (isSetMaybeEnumColumn()); + list.add(present_maybeEnumColumn); + if (present_maybeEnumColumn) + list.add(maybeEnumColumn.getValue()); + + boolean present_stringsColumn = true && (isSetStringsColumn()); + list.add(present_stringsColumn); + if (present_stringsColumn) + list.add(stringsColumn); + + boolean present_intSetColumn = true && (isSetIntSetColumn()); + list.add(present_intSetColumn); + if (present_intSetColumn) + list.add(intSetColumn); + + boolean present_intToStringColumn = true && (isSetIntToStringColumn()); + list.add(present_intToStringColumn); + if (present_intToStringColumn) + list.add(intToStringColumn); + + boolean present_complexColumn = true && (isSetComplexColumn()); + list.add(present_complexColumn); + if (present_complexColumn) + list.add(complexColumn); + + return list.hashCode(); + } + + @Override + public int compareTo(ParquetThriftCompat other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + + lastComparison = Boolean.valueOf(isSetBoolColumn()).compareTo(other.isSetBoolColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetBoolColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.boolColumn, other.boolColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetByteColumn()).compareTo(other.isSetByteColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetByteColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.byteColumn, other.byteColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetShortColumn()).compareTo(other.isSetShortColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetShortColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.shortColumn, other.shortColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetIntColumn()).compareTo(other.isSetIntColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetIntColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.intColumn, other.intColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLongColumn()).compareTo(other.isSetLongColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLongColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.longColumn, other.longColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetDoubleColumn()).compareTo(other.isSetDoubleColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetDoubleColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.doubleColumn, other.doubleColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetBinaryColumn()).compareTo(other.isSetBinaryColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetBinaryColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.binaryColumn, other.binaryColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetStringColumn()).compareTo(other.isSetStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetStringColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.stringColumn, other.stringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetEnumColumn()).compareTo(other.isSetEnumColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetEnumColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.enumColumn, other.enumColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeBoolColumn()).compareTo(other.isSetMaybeBoolColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeBoolColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeBoolColumn, other.maybeBoolColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeByteColumn()).compareTo(other.isSetMaybeByteColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeByteColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeByteColumn, other.maybeByteColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeShortColumn()).compareTo(other.isSetMaybeShortColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeShortColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeShortColumn, other.maybeShortColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeIntColumn()).compareTo(other.isSetMaybeIntColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeIntColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeIntColumn, other.maybeIntColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeLongColumn()).compareTo(other.isSetMaybeLongColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeLongColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeLongColumn, other.maybeLongColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeDoubleColumn()).compareTo(other.isSetMaybeDoubleColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeDoubleColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeDoubleColumn, other.maybeDoubleColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeBinaryColumn()).compareTo(other.isSetMaybeBinaryColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeBinaryColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeBinaryColumn, other.maybeBinaryColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeStringColumn()).compareTo(other.isSetMaybeStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeStringColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeStringColumn, other.maybeStringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeEnumColumn()).compareTo(other.isSetMaybeEnumColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeEnumColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeEnumColumn, other.maybeEnumColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetStringsColumn()).compareTo(other.isSetStringsColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetStringsColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.stringsColumn, other.stringsColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetIntSetColumn()).compareTo(other.isSetIntSetColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetIntSetColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.intSetColumn, other.intSetColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetIntToStringColumn()).compareTo(other.isSetIntToStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetIntToStringColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.intToStringColumn, other.intToStringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetComplexColumn()).compareTo(other.isSetComplexColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetComplexColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.complexColumn, other.complexColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("ParquetThriftCompat("); + boolean first = true; + + sb.append("boolColumn:"); + sb.append(this.boolColumn); + first = false; + if (!first) sb.append(", "); + sb.append("byteColumn:"); + sb.append(this.byteColumn); + first = false; + if (!first) sb.append(", "); + sb.append("shortColumn:"); + sb.append(this.shortColumn); + first = false; + if (!first) sb.append(", "); + sb.append("intColumn:"); + sb.append(this.intColumn); + first = false; + if (!first) sb.append(", "); + sb.append("longColumn:"); + sb.append(this.longColumn); + first = false; + if (!first) sb.append(", "); + sb.append("doubleColumn:"); + sb.append(this.doubleColumn); + first = false; + if (!first) sb.append(", "); + sb.append("binaryColumn:"); + if (this.binaryColumn == null) { + sb.append("null"); + } else { + org.apache.thrift.TBaseHelper.toString(this.binaryColumn, sb); + } + first = false; + if (!first) sb.append(", "); + sb.append("stringColumn:"); + if (this.stringColumn == null) { + sb.append("null"); + } else { + sb.append(this.stringColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("enumColumn:"); + if (this.enumColumn == null) { + sb.append("null"); + } else { + sb.append(this.enumColumn); + } + first = false; + if (isSetMaybeBoolColumn()) { + if (!first) sb.append(", "); + sb.append("maybeBoolColumn:"); + sb.append(this.maybeBoolColumn); + first = false; + } + if (isSetMaybeByteColumn()) { + if (!first) sb.append(", "); + sb.append("maybeByteColumn:"); + sb.append(this.maybeByteColumn); + first = false; + } + if (isSetMaybeShortColumn()) { + if (!first) sb.append(", "); + sb.append("maybeShortColumn:"); + sb.append(this.maybeShortColumn); + first = false; + } + if (isSetMaybeIntColumn()) { + if (!first) sb.append(", "); + sb.append("maybeIntColumn:"); + sb.append(this.maybeIntColumn); + first = false; + } + if (isSetMaybeLongColumn()) { + if (!first) sb.append(", "); + sb.append("maybeLongColumn:"); + sb.append(this.maybeLongColumn); + first = false; + } + if (isSetMaybeDoubleColumn()) { + if (!first) sb.append(", "); + sb.append("maybeDoubleColumn:"); + sb.append(this.maybeDoubleColumn); + first = false; + } + if (isSetMaybeBinaryColumn()) { + if (!first) sb.append(", "); + sb.append("maybeBinaryColumn:"); + if (this.maybeBinaryColumn == null) { + sb.append("null"); + } else { + org.apache.thrift.TBaseHelper.toString(this.maybeBinaryColumn, sb); + } + first = false; + } + if (isSetMaybeStringColumn()) { + if (!first) sb.append(", "); + sb.append("maybeStringColumn:"); + if (this.maybeStringColumn == null) { + sb.append("null"); + } else { + sb.append(this.maybeStringColumn); + } + first = false; + } + if (isSetMaybeEnumColumn()) { + if (!first) sb.append(", "); + sb.append("maybeEnumColumn:"); + if (this.maybeEnumColumn == null) { + sb.append("null"); + } else { + sb.append(this.maybeEnumColumn); + } + first = false; + } + if (!first) sb.append(", "); + sb.append("stringsColumn:"); + if (this.stringsColumn == null) { + sb.append("null"); + } else { + sb.append(this.stringsColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("intSetColumn:"); + if (this.intSetColumn == null) { + sb.append("null"); + } else { + sb.append(this.intSetColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("intToStringColumn:"); + if (this.intToStringColumn == null) { + sb.append("null"); + } else { + sb.append(this.intToStringColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("complexColumn:"); + if (this.complexColumn == null) { + sb.append("null"); + } else { + sb.append(this.complexColumn); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // alas, we cannot check 'boolColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'byteColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'shortColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'intColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'longColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'doubleColumn' because it's a primitive and you chose the non-beans generator. + if (binaryColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'binaryColumn' was not present! Struct: " + toString()); + } + if (stringColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'stringColumn' was not present! Struct: " + toString()); + } + if (enumColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'enumColumn' was not present! Struct: " + toString()); + } + if (stringsColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'stringsColumn' was not present! Struct: " + toString()); + } + if (intSetColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'intSetColumn' was not present! Struct: " + toString()); + } + if (intToStringColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'intToStringColumn' was not present! Struct: " + toString()); + } + if (complexColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'complexColumn' was not present! Struct: " + toString()); + } + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. + __isset_bitfield = 0; + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class ParquetThriftCompatStandardSchemeFactory implements SchemeFactory { + public ParquetThriftCompatStandardScheme getScheme() { + return new ParquetThriftCompatStandardScheme(); + } + } + + private static class ParquetThriftCompatStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // BOOL_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BOOL) { + struct.boolColumn = iprot.readBool(); + struct.setBoolColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // BYTE_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BYTE) { + struct.byteColumn = iprot.readByte(); + struct.setByteColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 3: // SHORT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I16) { + struct.shortColumn = iprot.readI16(); + struct.setShortColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 4: // INT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.intColumn = iprot.readI32(); + struct.setIntColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 5: // LONG_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I64) { + struct.longColumn = iprot.readI64(); + struct.setLongColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 6: // DOUBLE_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.DOUBLE) { + struct.doubleColumn = iprot.readDouble(); + struct.setDoubleColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 7: // BINARY_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.binaryColumn = iprot.readBinary(); + struct.setBinaryColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 8: // STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.stringColumn = iprot.readString(); + struct.setStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 9: // ENUM_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.enumColumn = org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setEnumColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 10: // MAYBE_BOOL_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BOOL) { + struct.maybeBoolColumn = iprot.readBool(); + struct.setMaybeBoolColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 11: // MAYBE_BYTE_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BYTE) { + struct.maybeByteColumn = iprot.readByte(); + struct.setMaybeByteColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 12: // MAYBE_SHORT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I16) { + struct.maybeShortColumn = iprot.readI16(); + struct.setMaybeShortColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 13: // MAYBE_INT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.maybeIntColumn = iprot.readI32(); + struct.setMaybeIntColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 14: // MAYBE_LONG_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I64) { + struct.maybeLongColumn = iprot.readI64(); + struct.setMaybeLongColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 15: // MAYBE_DOUBLE_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.DOUBLE) { + struct.maybeDoubleColumn = iprot.readDouble(); + struct.setMaybeDoubleColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 16: // MAYBE_BINARY_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.maybeBinaryColumn = iprot.readBinary(); + struct.setMaybeBinaryColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 17: // MAYBE_STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.maybeStringColumn = iprot.readString(); + struct.setMaybeStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 18: // MAYBE_ENUM_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.maybeEnumColumn = org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setMaybeEnumColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 19: // STRINGS_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list8 = iprot.readListBegin(); + struct.stringsColumn = new ArrayList(_list8.size); + String _elem9; + for (int _i10 = 0; _i10 < _list8.size; ++_i10) + { + _elem9 = iprot.readString(); + struct.stringsColumn.add(_elem9); + } + iprot.readListEnd(); + } + struct.setStringsColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 20: // INT_SET_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.SET) { + { + org.apache.thrift.protocol.TSet _set11 = iprot.readSetBegin(); + struct.intSetColumn = new HashSet(2*_set11.size); + int _elem12; + for (int _i13 = 0; _i13 < _set11.size; ++_i13) + { + _elem12 = iprot.readI32(); + struct.intSetColumn.add(_elem12); + } + iprot.readSetEnd(); + } + struct.setIntSetColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 21: // INT_TO_STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map14 = iprot.readMapBegin(); + struct.intToStringColumn = new HashMap(2*_map14.size); + int _key15; + String _val16; + for (int _i17 = 0; _i17 < _map14.size; ++_i17) + { + _key15 = iprot.readI32(); + _val16 = iprot.readString(); + struct.intToStringColumn.put(_key15, _val16); + } + iprot.readMapEnd(); + } + struct.setIntToStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 22: // COMPLEX_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map18 = iprot.readMapBegin(); + struct.complexColumn = new HashMap>(2*_map18.size); + int _key19; + List _val20; + for (int _i21 = 0; _i21 < _map18.size; ++_i21) + { + _key19 = iprot.readI32(); + { + org.apache.thrift.protocol.TList _list22 = iprot.readListBegin(); + _val20 = new ArrayList(_list22.size); + Nested _elem23; + for (int _i24 = 0; _i24 < _list22.size; ++_i24) + { + _elem23 = new Nested(); + _elem23.read(iprot); + _val20.add(_elem23); + } + iprot.readListEnd(); + } + struct.complexColumn.put(_key19, _val20); + } + iprot.readMapEnd(); + } + struct.setComplexColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + + // check for required fields of primitive type, which can't be checked in the validate method + if (!struct.isSetBoolColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'boolColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetByteColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'byteColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetShortColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'shortColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetIntColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'intColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetLongColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'longColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetDoubleColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'doubleColumn' was not found in serialized data! Struct: " + toString()); + } + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + oprot.writeFieldBegin(BOOL_COLUMN_FIELD_DESC); + oprot.writeBool(struct.boolColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(BYTE_COLUMN_FIELD_DESC); + oprot.writeByte(struct.byteColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(SHORT_COLUMN_FIELD_DESC); + oprot.writeI16(struct.shortColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(INT_COLUMN_FIELD_DESC); + oprot.writeI32(struct.intColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(LONG_COLUMN_FIELD_DESC); + oprot.writeI64(struct.longColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(DOUBLE_COLUMN_FIELD_DESC); + oprot.writeDouble(struct.doubleColumn); + oprot.writeFieldEnd(); + if (struct.binaryColumn != null) { + oprot.writeFieldBegin(BINARY_COLUMN_FIELD_DESC); + oprot.writeBinary(struct.binaryColumn); + oprot.writeFieldEnd(); + } + if (struct.stringColumn != null) { + oprot.writeFieldBegin(STRING_COLUMN_FIELD_DESC); + oprot.writeString(struct.stringColumn); + oprot.writeFieldEnd(); + } + if (struct.enumColumn != null) { + oprot.writeFieldBegin(ENUM_COLUMN_FIELD_DESC); + oprot.writeI32(struct.enumColumn.getValue()); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeBoolColumn()) { + oprot.writeFieldBegin(MAYBE_BOOL_COLUMN_FIELD_DESC); + oprot.writeBool(struct.maybeBoolColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeByteColumn()) { + oprot.writeFieldBegin(MAYBE_BYTE_COLUMN_FIELD_DESC); + oprot.writeByte(struct.maybeByteColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeShortColumn()) { + oprot.writeFieldBegin(MAYBE_SHORT_COLUMN_FIELD_DESC); + oprot.writeI16(struct.maybeShortColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeIntColumn()) { + oprot.writeFieldBegin(MAYBE_INT_COLUMN_FIELD_DESC); + oprot.writeI32(struct.maybeIntColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeLongColumn()) { + oprot.writeFieldBegin(MAYBE_LONG_COLUMN_FIELD_DESC); + oprot.writeI64(struct.maybeLongColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeDoubleColumn()) { + oprot.writeFieldBegin(MAYBE_DOUBLE_COLUMN_FIELD_DESC); + oprot.writeDouble(struct.maybeDoubleColumn); + oprot.writeFieldEnd(); + } + if (struct.maybeBinaryColumn != null) { + if (struct.isSetMaybeBinaryColumn()) { + oprot.writeFieldBegin(MAYBE_BINARY_COLUMN_FIELD_DESC); + oprot.writeBinary(struct.maybeBinaryColumn); + oprot.writeFieldEnd(); + } + } + if (struct.maybeStringColumn != null) { + if (struct.isSetMaybeStringColumn()) { + oprot.writeFieldBegin(MAYBE_STRING_COLUMN_FIELD_DESC); + oprot.writeString(struct.maybeStringColumn); + oprot.writeFieldEnd(); + } + } + if (struct.maybeEnumColumn != null) { + if (struct.isSetMaybeEnumColumn()) { + oprot.writeFieldBegin(MAYBE_ENUM_COLUMN_FIELD_DESC); + oprot.writeI32(struct.maybeEnumColumn.getValue()); + oprot.writeFieldEnd(); + } + } + if (struct.stringsColumn != null) { + oprot.writeFieldBegin(STRINGS_COLUMN_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.stringsColumn.size())); + for (String _iter25 : struct.stringsColumn) + { + oprot.writeString(_iter25); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.intSetColumn != null) { + oprot.writeFieldBegin(INT_SET_COLUMN_FIELD_DESC); + { + oprot.writeSetBegin(new org.apache.thrift.protocol.TSet(org.apache.thrift.protocol.TType.I32, struct.intSetColumn.size())); + for (int _iter26 : struct.intSetColumn) + { + oprot.writeI32(_iter26); + } + oprot.writeSetEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.intToStringColumn != null) { + oprot.writeFieldBegin(INT_TO_STRING_COLUMN_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.STRING, struct.intToStringColumn.size())); + for (Map.Entry _iter27 : struct.intToStringColumn.entrySet()) + { + oprot.writeI32(_iter27.getKey()); + oprot.writeString(_iter27.getValue()); + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.complexColumn != null) { + oprot.writeFieldBegin(COMPLEX_COLUMN_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.LIST, struct.complexColumn.size())); + for (Map.Entry> _iter28 : struct.complexColumn.entrySet()) + { + oprot.writeI32(_iter28.getKey()); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, _iter28.getValue().size())); + for (Nested _iter29 : _iter28.getValue()) + { + _iter29.write(oprot); + } + oprot.writeListEnd(); + } + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class ParquetThriftCompatTupleSchemeFactory implements SchemeFactory { + public ParquetThriftCompatTupleScheme getScheme() { + return new ParquetThriftCompatTupleScheme(); + } + } + + private static class ParquetThriftCompatTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + oprot.writeBool(struct.boolColumn); + oprot.writeByte(struct.byteColumn); + oprot.writeI16(struct.shortColumn); + oprot.writeI32(struct.intColumn); + oprot.writeI64(struct.longColumn); + oprot.writeDouble(struct.doubleColumn); + oprot.writeBinary(struct.binaryColumn); + oprot.writeString(struct.stringColumn); + oprot.writeI32(struct.enumColumn.getValue()); + { + oprot.writeI32(struct.stringsColumn.size()); + for (String _iter30 : struct.stringsColumn) + { + oprot.writeString(_iter30); + } + } + { + oprot.writeI32(struct.intSetColumn.size()); + for (int _iter31 : struct.intSetColumn) + { + oprot.writeI32(_iter31); + } + } + { + oprot.writeI32(struct.intToStringColumn.size()); + for (Map.Entry _iter32 : struct.intToStringColumn.entrySet()) + { + oprot.writeI32(_iter32.getKey()); + oprot.writeString(_iter32.getValue()); + } + } + { + oprot.writeI32(struct.complexColumn.size()); + for (Map.Entry> _iter33 : struct.complexColumn.entrySet()) + { + oprot.writeI32(_iter33.getKey()); + { + oprot.writeI32(_iter33.getValue().size()); + for (Nested _iter34 : _iter33.getValue()) + { + _iter34.write(oprot); + } + } + } + } + BitSet optionals = new BitSet(); + if (struct.isSetMaybeBoolColumn()) { + optionals.set(0); + } + if (struct.isSetMaybeByteColumn()) { + optionals.set(1); + } + if (struct.isSetMaybeShortColumn()) { + optionals.set(2); + } + if (struct.isSetMaybeIntColumn()) { + optionals.set(3); + } + if (struct.isSetMaybeLongColumn()) { + optionals.set(4); + } + if (struct.isSetMaybeDoubleColumn()) { + optionals.set(5); + } + if (struct.isSetMaybeBinaryColumn()) { + optionals.set(6); + } + if (struct.isSetMaybeStringColumn()) { + optionals.set(7); + } + if (struct.isSetMaybeEnumColumn()) { + optionals.set(8); + } + oprot.writeBitSet(optionals, 9); + if (struct.isSetMaybeBoolColumn()) { + oprot.writeBool(struct.maybeBoolColumn); + } + if (struct.isSetMaybeByteColumn()) { + oprot.writeByte(struct.maybeByteColumn); + } + if (struct.isSetMaybeShortColumn()) { + oprot.writeI16(struct.maybeShortColumn); + } + if (struct.isSetMaybeIntColumn()) { + oprot.writeI32(struct.maybeIntColumn); + } + if (struct.isSetMaybeLongColumn()) { + oprot.writeI64(struct.maybeLongColumn); + } + if (struct.isSetMaybeDoubleColumn()) { + oprot.writeDouble(struct.maybeDoubleColumn); + } + if (struct.isSetMaybeBinaryColumn()) { + oprot.writeBinary(struct.maybeBinaryColumn); + } + if (struct.isSetMaybeStringColumn()) { + oprot.writeString(struct.maybeStringColumn); + } + if (struct.isSetMaybeEnumColumn()) { + oprot.writeI32(struct.maybeEnumColumn.getValue()); + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + struct.boolColumn = iprot.readBool(); + struct.setBoolColumnIsSet(true); + struct.byteColumn = iprot.readByte(); + struct.setByteColumnIsSet(true); + struct.shortColumn = iprot.readI16(); + struct.setShortColumnIsSet(true); + struct.intColumn = iprot.readI32(); + struct.setIntColumnIsSet(true); + struct.longColumn = iprot.readI64(); + struct.setLongColumnIsSet(true); + struct.doubleColumn = iprot.readDouble(); + struct.setDoubleColumnIsSet(true); + struct.binaryColumn = iprot.readBinary(); + struct.setBinaryColumnIsSet(true); + struct.stringColumn = iprot.readString(); + struct.setStringColumnIsSet(true); + struct.enumColumn = org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setEnumColumnIsSet(true); + { + org.apache.thrift.protocol.TList _list35 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.stringsColumn = new ArrayList(_list35.size); + String _elem36; + for (int _i37 = 0; _i37 < _list35.size; ++_i37) + { + _elem36 = iprot.readString(); + struct.stringsColumn.add(_elem36); + } + } + struct.setStringsColumnIsSet(true); + { + org.apache.thrift.protocol.TSet _set38 = new org.apache.thrift.protocol.TSet(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.intSetColumn = new HashSet(2*_set38.size); + int _elem39; + for (int _i40 = 0; _i40 < _set38.size; ++_i40) + { + _elem39 = iprot.readI32(); + struct.intSetColumn.add(_elem39); + } + } + struct.setIntSetColumnIsSet(true); + { + org.apache.thrift.protocol.TMap _map41 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.intToStringColumn = new HashMap(2*_map41.size); + int _key42; + String _val43; + for (int _i44 = 0; _i44 < _map41.size; ++_i44) + { + _key42 = iprot.readI32(); + _val43 = iprot.readString(); + struct.intToStringColumn.put(_key42, _val43); + } + } + struct.setIntToStringColumnIsSet(true); + { + org.apache.thrift.protocol.TMap _map45 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.LIST, iprot.readI32()); + struct.complexColumn = new HashMap>(2*_map45.size); + int _key46; + List _val47; + for (int _i48 = 0; _i48 < _map45.size; ++_i48) + { + _key46 = iprot.readI32(); + { + org.apache.thrift.protocol.TList _list49 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); + _val47 = new ArrayList(_list49.size); + Nested _elem50; + for (int _i51 = 0; _i51 < _list49.size; ++_i51) + { + _elem50 = new Nested(); + _elem50.read(iprot); + _val47.add(_elem50); + } + } + struct.complexColumn.put(_key46, _val47); + } + } + struct.setComplexColumnIsSet(true); + BitSet incoming = iprot.readBitSet(9); + if (incoming.get(0)) { + struct.maybeBoolColumn = iprot.readBool(); + struct.setMaybeBoolColumnIsSet(true); + } + if (incoming.get(1)) { + struct.maybeByteColumn = iprot.readByte(); + struct.setMaybeByteColumnIsSet(true); + } + if (incoming.get(2)) { + struct.maybeShortColumn = iprot.readI16(); + struct.setMaybeShortColumnIsSet(true); + } + if (incoming.get(3)) { + struct.maybeIntColumn = iprot.readI32(); + struct.setMaybeIntColumnIsSet(true); + } + if (incoming.get(4)) { + struct.maybeLongColumn = iprot.readI64(); + struct.setMaybeLongColumnIsSet(true); + } + if (incoming.get(5)) { + struct.maybeDoubleColumn = iprot.readDouble(); + struct.setMaybeDoubleColumnIsSet(true); + } + if (incoming.get(6)) { + struct.maybeBinaryColumn = iprot.readBinary(); + struct.setMaybeBinaryColumnIsSet(true); + } + if (incoming.get(7)) { + struct.maybeStringColumn = iprot.readString(); + struct.setMaybeStringColumnIsSet(true); + } + if (incoming.get(8)) { + struct.maybeEnumColumn = org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setMaybeEnumColumnIsSet(true); + } + } + } + +} + diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java new file mode 100644 index 0000000000000..5315c6aae9372 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java @@ -0,0 +1,51 @@ +/** + * Autogenerated by Thrift Compiler (0.9.2) + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +package org.apache.spark.sql.parquet.test.thrift; + + +import java.util.Map; +import java.util.HashMap; +import org.apache.thrift.TEnum; + +public enum Suit implements org.apache.thrift.TEnum { + SPADES(0), + HEARTS(1), + DIAMONDS(2), + CLUBS(3); + + private final int value; + + private Suit(int value) { + this.value = value; + } + + /** + * Get the integer value of this enum value, as defined in the Thrift IDL. + */ + public int getValue() { + return value; + } + + /** + * Find a the enum type by its integer value, as defined in the Thrift IDL. + * @return null if the value is not found. + */ + public static Suit findByValue(int value) { + switch (value) { + case 0: + return SPADES; + case 1: + return HEARTS; + case 2: + return DIAMONDS; + case 3: + return CLUBS; + default: + return null; + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala new file mode 100644 index 0000000000000..bfa427349ff6a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.avro.AvroParquetWriter + +import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{Row, SQLContext} + +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { + import ParquetCompatibilityTest._ + + override val sqlContext: SQLContext = TestSQLContext + + override protected def beforeAll(): Unit = { + super.beforeAll() + + val writer = + new AvroParquetWriter[ParquetAvroCompat]( + new Path(parquetStore.getCanonicalPath), + ParquetAvroCompat.getClassSchema) + + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + writer.close() + } + + test("Read Parquet file generated by parquet-avro") { + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(parquetStore.getCanonicalPath)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes, + s"val_$i", + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i: Integer), + nullable(i.toLong: java.lang.Long), + nullable(i.toFloat + 0.1f: java.lang.Float), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i".getBytes), + nullable(s"val_$i"), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + + def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { + def nullable[T <: AnyRef] = makeNullable[T](i) _ + + def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { + mapAsJavaMap(Seq.tabulate(3) { n => + (i + n).toString -> seqAsJavaList(Seq.tabulate(3) { m => + Nested + .newBuilder() + .setNestedIntsColumn(seqAsJavaList(Seq.tabulate(3)(j => i + j + m))) + .setNestedStringColumn(s"val_${i + m}") + .build() + }) + }.toMap) + } + + ParquetAvroCompat + .newBuilder() + .setBoolColumn(i % 2 == 0) + .setIntColumn(i) + .setLongColumn(i.toLong * 10) + .setFloatColumn(i.toFloat + 0.1f) + .setDoubleColumn(i.toDouble + 0.2d) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes)) + .setStringColumn(s"val_$i") + + .setMaybeBoolColumn(nullable(i % 2 == 0: java.lang.Boolean)) + .setMaybeIntColumn(nullable(i: Integer)) + .setMaybeLongColumn(nullable(i.toLong: java.lang.Long)) + .setMaybeFloatColumn(nullable(i.toFloat + 0.1f: java.lang.Float)) + .setMaybeDoubleColumn(nullable(i.toDouble + 0.2d: java.lang.Double)) + .setMaybeBinaryColumn(nullable(ByteBuffer.wrap(s"val_$i".getBytes))) + .setMaybeStringColumn(nullable(s"val_$i")) + + .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}")) + .setStringToIntColumn( + mapAsJavaMap(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap)) + .setComplexColumn(makeComplexColumn(i)) + + .build() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala new file mode 100644 index 0000000000000..b4cdfd9e98f6f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet +import java.io.File + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.schema.MessageType +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.QueryTest +import org.apache.spark.util.Utils + +abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { + protected var parquetStore: File = _ + + override protected def beforeAll(): Unit = { + parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") + parquetStore.delete() + } + + override protected def afterAll(): Unit = { + Utils.deleteRecursively(parquetStore) + } + + def readParquetSchema(path: String): MessageType = { + val fsPath = new Path(path) + val fs = fsPath.getFileSystem(configuration) + val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot(_.getPath.getName.startsWith("_")) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) + footers.head.getParquetMetadata.getFileMetaData.getSchema + } +} + +object ParquetCompatibilityTest { + def makeNullable[T <: AnyRef](i: Int)(f: => T): T = { + if (i % 3 == 0) null.asInstanceOf[T] else f + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala new file mode 100644 index 0000000000000..d22066cabc567 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.thrift.ThriftParquetWriter + +import org.apache.spark.sql.parquet.test.thrift.{Nested, ParquetThriftCompat, Suit} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{Row, SQLContext} + +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { + import ParquetCompatibilityTest._ + + override val sqlContext: SQLContext = TestSQLContext + + override protected def beforeAll(): Unit = { + super.beforeAll() + + val writer = + new ThriftParquetWriter[ParquetThriftCompat]( + new Path(parquetStore.getCanonicalPath), + classOf[ParquetThriftCompat], + CompressionCodecName.SNAPPY) + + (0 until 10).foreach(i => writer.write(makeParquetThriftCompat(i))) + writer.close() + } + + test("Read Parquet file generated by parquet-thrift") { + logInfo( + s"""Schema of the Parquet file written by parquet-thrift: + |${readParquetSchema(parquetStore.getCanonicalPath)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always + // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume + // Thrift `STRING`s are encoded using UTF-8. + s"val_$i", + s"val_$i", + // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings + Suit.values()(i % 4).name(), + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i.toByte: java.lang.Byte), + nullable((i + 1).toShort: java.lang.Short), + nullable(i + 2: Integer), + nullable((i * 10).toLong: java.lang.Long), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i"), + nullable(s"val_$i"), + nullable(Suit.values()(i % 4).name()), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + // Thrift `SET`s are converted to Parquet `LIST`s + Seq(i), + Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap, + Seq.tabulate(3) { n => + (i + n) -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + + def makeParquetThriftCompat(i: Int): ParquetThriftCompat = { + def makeComplexColumn(i: Int): JMap[Integer, JList[Nested]] = { + mapAsJavaMap(Seq.tabulate(3) { n => + (i + n: Integer) -> seqAsJavaList(Seq.tabulate(3) { m => + new Nested( + seqAsJavaList(Seq.tabulate(3)(j => i + j + m)), + s"val_${i + m}") + }) + }.toMap) + } + + val value = + new ParquetThriftCompat( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + ByteBuffer.wrap(s"val_$i".getBytes), + s"val_$i", + Suit.values()(i % 4), + + seqAsJavaList(Seq.tabulate(3)(n => s"arr_${i + n}")), + setAsJavaSet(Set(i)), + mapAsJavaMap(Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap), + makeComplexColumn(i)) + + if (i % 3 == 0) { + value + } else { + value + .setMaybeBoolColumn(i % 2 == 0) + .setMaybeByteColumn(i.toByte) + .setMaybeShortColumn((i + 1).toShort) + .setMaybeIntColumn(i + 2) + .setMaybeLongColumn(i.toLong * 10) + .setMaybeDoubleColumn(i.toDouble + 0.2d) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes)) + .setMaybeStringColumn(s"val_$i") + .setMaybeEnumColumn(Suit.values()(i % 4)) + } + } +} diff --git a/sql/core/src/test/scripts/gen-code.sh b/sql/core/src/test/scripts/gen-code.sh new file mode 100755 index 0000000000000..5d8d8ad08555c --- /dev/null +++ b/sql/core/src/test/scripts/gen-code.sh @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. +BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +thrift\ + --gen java\ + -out $BASEDIR/gen-java\ + $BASEDIR/thrift/parquet-compat.thrift + +avro-tools idl $BASEDIR/avro/parquet-compat.avdl > $BASEDIR/avro/parquet-compat.avpr +avro-tools compile -string protocol $BASEDIR/avro/parquet-compat.avpr $BASEDIR/gen-java diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift new file mode 100644 index 0000000000000..fa5ed8c62306a --- /dev/null +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace java org.apache.spark.sql.parquet.test.thrift + +enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS +} + +struct Nested { + 1: required list nestedIntsColumn; + 2: required string nestedStringColumn; +} + +/** + * This is a test struct for testing parquet-thrift compatibility. + */ +struct ParquetThriftCompat { + 1: required bool boolColumn; + 2: required byte byteColumn; + 3: required i16 shortColumn; + 4: required i32 intColumn; + 5: required i64 longColumn; + 6: required double doubleColumn; + 7: required binary binaryColumn; + 8: required string stringColumn; + 9: required Suit enumColumn + + 10: optional bool maybeBoolColumn; + 11: optional byte maybeByteColumn; + 12: optional i16 maybeShortColumn; + 13: optional i32 maybeIntColumn; + 14: optional i64 maybeLongColumn; + 15: optional double maybeDoubleColumn; + 16: optional binary maybeBinaryColumn; + 17: optional string maybeStringColumn; + 18: optional Suit maybeEnumColumn; + + 19: required list stringsColumn; + 20: required set intSetColumn; + 21: required map intToStringColumn; + 22: required map> complexColumn; +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala new file mode 100644 index 0000000000000..bb5f1febe9ad4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.{Row, SQLConf, SQLContext} + +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { + import ParquetCompatibilityTest.makeNullable + + override val sqlContext: SQLContext = TestHive + + override protected def beforeAll(): Unit = { + super.beforeAll() + + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + sqlContext.sql( + s"""CREATE TABLE parquet_compat( + | bool_column BOOLEAN, + | byte_column TINYINT, + | short_column SMALLINT, + | int_column INT, + | long_column BIGINT, + | float_column FLOAT, + | double_column DOUBLE, + | + | strings_column ARRAY, + | int_to_string_column MAP + |) + |STORED AS PARQUET + |LOCATION '${parquetStore.getCanonicalPath}' + """.stripMargin) + + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } + } + + override protected def afterAll(): Unit = { + sqlContext.sql("DROP TABLE parquet_compat") + } + + test("Read Parquet file generated by parquet-hive") { + logInfo( + s"""Schema of the Parquet file written by parquet-hive: + |${readParquetSchema(parquetStore.getCanonicalPath)} + """.stripMargin) + + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) + } + } + + def makeRows: Seq[Row] = { + (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i.toByte: java.lang.Byte), + nullable((i + 1).toShort: java.lang.Short), + nullable(i + 2: Integer), + nullable(i.toLong * 10: java.lang.Long), + nullable(i.toFloat + 0.1f: java.lang.Float), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(Seq.tabulate(3)(n => s"arr_${i + n}")), + nullable(Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c2e09800933b5..9d79a4b007d66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,14 +21,16 @@ import java.io.File import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql._ import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SQLConf, SaveMode} import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. @@ -685,6 +687,31 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { sql("drop table spark_6016_fix") } + + test("SPARK-8811: compatibility with array of struct in Hive") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("array_of_struct") { + val conf = Seq( + HiveContext.CONVERT_METASTORE_PARQUET.key -> "false", + SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") + + withSQLConf(conf: _*) { + sql( + s"""CREATE TABLE array_of_struct + |STORED AS PARQUET LOCATION '$path' + |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) + """.stripMargin) + + checkAnswer( + sqlContext.read.parquet(path), + Row("1st", "2nd", Seq(Row("val_a", "val_b")))) + } + } + } + } } class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { @@ -762,7 +789,9 @@ class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { /** * A collection of tests for parquet data with various forms of partitioning. */ -abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override def sqlContext: SQLContext = TestHive + var partitionedTableDir: File = null var normalTableDir: File = null var partitionedTableDirWithKey: File = null From 381cb161ba4e3a30f2da3c4ef4ee19869d51f101 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 8 Jul 2015 16:21:28 -0700 Subject: [PATCH 198/274] [SPARK-8068] [MLLIB] Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib Author: Yanbo Liang Closes #7286 from yanboliang/spark-8068 and squashes the following commits: 6109fe1 [Yanbo Liang] Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib --- python/pyspark/mllib/evaluation.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index c5cf3a4e7ff22..f21403707e12a 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -152,6 +152,10 @@ class MulticlassMetrics(JavaModelWrapper): >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) + >>> metrics.confusionMatrix().toArray() + array([[ 2., 1., 1.], + [ 1., 3., 0.], + [ 0., 0., 1.]]) >>> metrics.falsePositiveRate(0.0) 0.2... >>> metrics.precision(1.0) @@ -186,6 +190,13 @@ def __init__(self, predictionAndLabels): java_model = java_class(df._jdf) super(MulticlassMetrics, self).__init__(java_model) + def confusionMatrix(self): + """ + Returns confusion matrix: predicted classes are in columns, + they are ordered by class label ascending, as in "labels". + """ + return self.call("confusionMatrix") + def truePositiveRate(self, label): """ Returns true positive rate for a given label (category). From 8c32b2e870c7c250a63e838718df833edf6dea07 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 8 Jul 2015 16:27:11 -0700 Subject: [PATCH 199/274] [SPARK-8877] [MLLIB] Public API for association rule generation Adds FPGrowth.generateAssociationRules to public API for generating association rules after mining frequent itemsets. Author: Feynman Liang Closes #7271 from feynmanliang/SPARK-8877 and squashes the following commits: 83b8baf [Feynman Liang] Add API Doc 867abff [Feynman Liang] Add FPGrowth.generateAssociationRules and change access modifiers for AssociationRules --- .../spark/mllib/fpm/AssociationRules.scala | 5 ++- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 11 ++++- .../spark/mllib/fpm/FPGrowthSuite.scala | 42 +++++++++++++++++++ 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 4a0f842f3338d..7e2bbfe31c1b7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD * association rules which have a single item as the consequent. */ @Experimental -class AssociationRules private ( +class AssociationRules private[fpm] ( private var minConfidence: Double) extends Logging with Serializable { /** @@ -45,6 +45,7 @@ class AssociationRules private ( * Sets the minimal confidence (default: `0.8`). */ def setMinConfidence(minConfidence: Double): this.type = { + require(minConfidence >= 0.0 && minConfidence <= 1.0) this.minConfidence = minConfidence this } @@ -91,7 +92,7 @@ object AssociationRules { * @tparam Item item type */ @Experimental - class Rule[Item] private[mllib] ( + class Rule[Item] private[fpm] ( val antecedent: Array[Item], val consequent: Array[Item], freqUnion: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 0da59e812d5f9..9cb9a00dbd9c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -40,7 +40,16 @@ import org.apache.spark.storage.StorageLevel * @tparam Item item type */ @Experimental -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { + /** + * Generates association rules for the [[Item]]s in [[freqItemsets]]. + * @param confidence minimal confidence of the rules produced + */ + def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { + val associationRules = new AssociationRules(confidence) + associationRules.run(freqItemsets) + } +} /** * :: Experimental :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index ddc296a428907..4a9bfdb348d9f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -132,6 +132,48 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model1.freqItemsets.count() === 625) } + test("FP-Growth String type association rule generation") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + ars = apriori(transactions, + parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + > nrow(arsDF) + [1] 23 + > sum(arsDF$confidence == 1) + [1] 23 + */ + val rules = (new FPGrowth()) + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + .generateAssociationRules(0.9) + .collect() + + assert(rules.size === 23) + assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + } + test("FP-Growth using Int type") { val transactions = Seq( "1 2 3", From f472b8cdc00839780dc79be0bbe53a098cde230c Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 8 Jul 2015 16:32:00 -0700 Subject: [PATCH 200/274] [SPARK-5016] [MLLIB] Distribute GMM mixture components to executors Distribute expensive portions of computation for Gaussian mixture components (in particular, pre-computation of `MultivariateGaussian.rootSigmaInv`, the inverse covariance matrix and covariance determinant) across executors. Repost of PR#4654. Notes for reviewers: * What should be the policy for when to distribute computation. Always? When numClusters > threshold? User-specified param? TODO: * Performance testing and comparison for large number of clusters Author: Feynman Liang Closes #7166 from feynmanliang/GMM_parallel_mixtures and squashes the following commits: 4f351fa [Feynman Liang] Update heuristic and scaladoc 5ea947e [Feynman Liang] Fix parallelization logic 00eb7db [Feynman Liang] Add helper method for GMM's M step, remove distributeGaussians flag e7c8127 [Feynman Liang] Add distributeGaussians flag and tests 1da3c7f [Feynman Liang] Distribute mixtures --- .../mllib/clustering/GaussianMixture.scala | 44 +++++++++++++++---- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index fc509d2ba1470..e459367333d26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -140,6 +140,10 @@ class GaussianMixture private ( // Get length of the input vectors val d = breezeData.first().length + // Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when + // d > 25 except for when k is very small + val distributeGaussians = ((k - 1.0) / k) * d > 25 + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and @@ -171,14 +175,25 @@ class GaussianMixture private ( // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum - var i = 0 - while (i < k) { - val mu = sums.means(i) / sums.weights(i) - BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu), - Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) - weights(i) = sums.weights(i) / sumWeights - gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) - i = i + 1 + + if (distributeGaussians) { + val numPartitions = math.min(k, 1024) + val tuples = + Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i))) + val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) => + updateWeightsAndGaussians(mean, sigma, weight, sumWeights) + }.collect.unzip + Array.copy(ws, 0, weights, 0, ws.length) + Array.copy(gs, 0, gaussians, 0, gs.length) + } else { + var i = 0 + while (i < k) { + val (weight, gaussian) = + updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights) + weights(i) = weight + gaussians(i) = gaussian + i = i + 1 + } } llhp = llh // current becomes previous @@ -192,6 +207,19 @@ class GaussianMixture private ( /** Java-friendly version of [[run()]] */ def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) + private def updateWeightsAndGaussians( + mean: BDV[Double], + sigma: BreezeMatrix[Double], + weight: Double, + sumWeights: Double): (Double, MultivariateGaussian) = { + val mu = (mean /= weight) + BLAS.syr(-weight, Vectors.fromBreeze(mu), + Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix]) + val newWeight = weight / sumWeights + val newGaussian = new MultivariateGaussian(mu, sigma / weight) + (newWeight, newGaussian) + } + /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) From 2a4f88b6c16f2991e63b17c0e103bcd79f04dbbc Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 8 Jul 2015 18:09:39 -0700 Subject: [PATCH 201/274] [SPARK-8914][SQL] Remove RDDApi As rxin suggested in #7298 , we should consider to remove `RDDApi`. Author: Kousuke Saruta Closes #7302 from sarutak/remove-rddapi and squashes the following commits: e495d35 [Kousuke Saruta] Fixed mima cb7ebb9 [Kousuke Saruta] Removed overriding RDDApi --- project/MimaExcludes.scala | 5 ++ .../org/apache/spark/sql/DataFrame.scala | 39 ++++++----- .../scala/org/apache/spark/sql/RDDApi.scala | 67 ------------------- 3 files changed, 24 insertions(+), 87 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7346d804632bc..57a86bf8deb64 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,7 +70,12 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrix.numActives") + ) ++ Seq( + // SPARK-8914 Remove RDDApi + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.RDDApi") ) + case v if v.startsWith("1.4") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f33e19a0cb7dd..eeefc85255d14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -115,8 +115,7 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) - extends RDDApi[Row] with Serializable { + @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) extends Serializable { /** * A constructor that automatically analyzes the logical plan. @@ -1320,14 +1319,14 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - override def first(): Row = head() + def first(): Row = head() /** * Returns a new RDD by applying a function to all rows of this DataFrame. * @group rdd * @since 1.3.0 */ - override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) + def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) /** * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], @@ -1335,14 +1334,14 @@ class DataFrame private[sql]( * @group rdd * @since 1.3.0 */ - override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) /** * Returns a new RDD by applying a function to each partition of this DataFrame. * @group rdd * @since 1.3.0 */ - override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { rdd.mapPartitions(f) } @@ -1351,49 +1350,49 @@ class DataFrame private[sql]( * @group rdd * @since 1.3.0 */ - override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + def foreach(f: Row => Unit): Unit = rdd.foreach(f) /** * Applies a function f to each partition of this [[DataFrame]]. * @group rdd * @since 1.3.0 */ - override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) /** * Returns the first `n` rows in the [[DataFrame]]. * @group action * @since 1.3.0 */ - override def take(n: Int): Array[Row] = head(n) + def take(n: Int): Array[Row] = head(n) /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. * @group action * @since 1.3.0 */ - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() + def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() /** * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. * @group action * @since 1.3.0 */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) + def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) /** * Returns the number of rows in the [[DataFrame]]. * @group action * @since 1.3.0 */ - override def count(): Long = groupBy().count().collect().head.getLong(0) + def count(): Long = groupBy().count().collect().head.getLong(0) /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. * @group rdd * @since 1.3.0 */ - override def repartition(numPartitions: Int): DataFrame = { + def repartition(numPartitions: Int): DataFrame = { Repartition(numPartitions, shuffle = true, logicalPlan) } @@ -1405,7 +1404,7 @@ class DataFrame private[sql]( * @group rdd * @since 1.4.0 */ - override def coalesce(numPartitions: Int): DataFrame = { + def coalesce(numPartitions: Int): DataFrame = { Repartition(numPartitions, shuffle = false, logicalPlan) } @@ -1415,13 +1414,13 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - override def distinct(): DataFrame = dropDuplicates() + def distinct(): DataFrame = dropDuplicates() /** * @group basic * @since 1.3.0 */ - override def persist(): this.type = { + def persist(): this.type = { sqlContext.cacheManager.cacheQuery(this) this } @@ -1430,13 +1429,13 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ - override def cache(): this.type = persist() + def cache(): this.type = persist() /** * @group basic * @since 1.3.0 */ - override def persist(newLevel: StorageLevel): this.type = { + def persist(newLevel: StorageLevel): this.type = { sqlContext.cacheManager.cacheQuery(this, None, newLevel) this } @@ -1445,7 +1444,7 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ - override def unpersist(blocking: Boolean): this.type = { + def unpersist(blocking: Boolean): this.type = { sqlContext.cacheManager.tryUncacheQuery(this, blocking) this } @@ -1454,7 +1453,7 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ - override def unpersist(): this.type = unpersist(blocking = false) + def unpersist(): this.type = unpersist(blocking = false) ///////////////////////////////////////////////////////////////////////////// // I/O diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala deleted file mode 100644 index 63dbab19947c0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.reflect.ClassTag - -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - - -/** - * An internal interface defining the RDD-like methods for [[DataFrame]]. - * Please use [[DataFrame]] directly, and do NOT use this. - */ -private[sql] trait RDDApi[T] { - - def cache(): this.type - - def persist(): this.type - - def persist(newLevel: StorageLevel): this.type - - def unpersist(): this.type - - def unpersist(blocking: Boolean): this.type - - def map[R: ClassTag](f: T => R): RDD[R] - - def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] - - def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] - - def foreach(f: T => Unit): Unit - - def foreachPartition(f: Iterator[T] => Unit): Unit - - def take(n: Int): Array[T] - - def collect(): Array[T] - - def collectAsList(): java.util.List[T] - - def count(): Long - - def first(): T - - def repartition(numPartitions: Int): DataFrame - - def coalesce(numPartitions: Int): DataFrame - - def distinct: DataFrame -} From 74d8d3d928cc9a7386b68588ac89ae042847d146 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 8 Jul 2015 18:22:53 -0700 Subject: [PATCH 202/274] [SPARK-8450] [SQL] [PYSARK] cleanup type converter for Python DataFrame This PR fixes the converter for Python DataFrame, especially for DecimalType Closes #7106 Author: Davies Liu Closes #7131 from davies/decimal_python and squashes the following commits: 4d3c234 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 20531d6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 7d73168 [Davies Liu] fix conflit 6cdd86a [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 7104e97 [Davies Liu] improve type infer 9cd5a21 [Davies Liu] run python tests with SPARK_PREPEND_CLASSES 829a05b [Davies Liu] fix UDT in python c99e8c5 [Davies Liu] fix mima c46814a [Davies Liu] convert decimal for Python DataFrames --- .../apache/spark/mllib/linalg/Matrices.scala | 10 +- .../apache/spark/mllib/linalg/Vectors.scala | 16 +--- project/MimaExcludes.scala | 5 +- python/pyspark/sql/tests.py | 13 +++ python/pyspark/sql/types.py | 4 + python/run-tests.py | 3 +- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 28 +----- .../spark/sql/execution/pythonUDFs.scala | 95 ++++++++++--------- 9 files changed, 84 insertions(+), 94 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 75e7004464af9..0df07663405a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ /** * Trait for a local matrix. @@ -147,7 +147,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { )) } - override def serialize(obj: Any): Row = { + override def serialize(obj: Any): InternalRow = { val row = new GenericMutableRow(7) obj match { case sm: SparseMatrix => @@ -173,9 +173,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def deserialize(datum: Any): Matrix = { datum match { - // TODO: something wrong with UDT serialization, should never happen. - case m: Matrix => m - case row: Row => + case row: InternalRow => require(row.length == 7, s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") val tpe = row.getByte(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c9c27425d2877..e048b01d92462 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ @@ -175,7 +175,7 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) } - override def serialize(obj: Any): Row = { + override def serialize(obj: Any): InternalRow = { obj match { case SparseVector(size, indices, values) => val row = new GenericMutableRow(4) @@ -191,17 +191,12 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { row.setNullAt(2) row.update(3, values.toSeq) row - // TODO: There are bugs in UDT serialization because we don't have a clear separation between - // TODO: internal SQL types and language specific types (including UDT). UDT serialize and - // TODO: deserialize may get called twice. See SPARK-7186. - case row: Row => - row } } override def deserialize(datum: Any): Vector = { datum match { - case row: Row => + case row: InternalRow => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") val tpe = row.getByte(0) @@ -215,11 +210,6 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } - // TODO: There are bugs in UDT serialization because we don't have a clear separation between - // TODO: internal SQL types and language specific types (including UDT). UDT serialize and - // TODO: deserialize may get called twice. See SPARK-7186. - case v: Vector => - v } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 57a86bf8deb64..821aadd477ef3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -63,7 +63,10 @@ object MimaExcludes { // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), // Parquet support is considered private. - excludePackage("org.apache.spark.sql.parquet") + excludePackage("org.apache.spark.sql.parquet"), + // local function inside a method + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1") ) ++ Seq( // SPARK-8479 Add numNonzeros and numActives to Matrix. ProblemFilters.exclude[MissingMethodProblem]( diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 333378c7f1854..66827d48850d9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -700,6 +700,19 @@ def test_time_with_timezone(self): self.assertTrue(now - now1 < datetime.timedelta(0.001)) self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + def test_decimal(self): + from decimal import Decimal + schema = StructType([StructField("decimal", DecimalType(10, 5))]) + df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema) + row = df.select(df.decimal + 1).first() + self.assertEqual(row[0], Decimal("4.14159")) + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.parquet(tmpPath) + df2 = self.sqlCtx.read.parquet(tmpPath) + row = df2.first() + self.assertEqual(row[0], Decimal("3.14159")) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 160df40d65cc1..7e64cb0b54dba 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1069,6 +1069,10 @@ def _verify_type(obj, dataType): if obj is None: return + # StringType can work with any types + if isinstance(dataType, StringType): + return + if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) diff --git a/python/run-tests.py b/python/run-tests.py index 7638854def2e8..cc560779373b3 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -72,7 +72,8 @@ def print_red(text): def run_individual_python_test(test_name, pyspark_python): - env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + env = dict(os.environ) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index eeefc85255d14..d9f987ae0252f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1549,8 +1549,8 @@ class DataFrame private[sql]( * Converts a JavaRDD to a PythonRDD. */ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + val structType = schema // capture it for closure + val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD() SerDeUtil.javaToPython(jrdd) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 079f31ab8fe6d..477dea9164726 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1044,33 +1044,7 @@ class SQLContext(@transient val sparkContext: SparkContext) rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - def needsConversion(dataType: DataType): Boolean = dataType match { - case ByteType => true - case ShortType => true - case LongType => true - case FloatType => true - case DateType => true - case TimestampType => true - case StringType => true - case ArrayType(_, _) => true - case MapType(_, _, _) => true - case StructType(_) => true - case udt: UserDefinedType[_] => needsConversion(udt.sqlType) - case other => false - } - - val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { - rdd.map(m => m.zip(schema.fields).map { - case (value, field) => EvaluatePython.fromJava(value, field.dataType) - }) - } else { - rdd - } - - val rowRdd = convertedRdd.mapPartitions { iter => - iter.map { m => new GenericInternalRow(m): InternalRow} - } - + val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 6946e798b71b0..1c8130b07c7fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -24,20 +24,19 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} -import org.apache.spark.{Accumulator, Logging => SparkLogging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Accumulator, Logging => SparkLogging} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -125,59 +124,86 @@ object EvaluatePython { new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) /** - * Helper for converting a Scala object to a java suitable for pyspark serialization. + * Helper for converting from Catalyst type to java type suitable for Pyrolite. */ def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Row, struct: StructType) => + case (row: InternalRow, struct: StructType) => val fields = struct.fields.map(field => field.dataType) - row.toSeq.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray + rowToArray(row, fields) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava - case (list: JList[_], array: ArrayType) => - list.map(x => toJava(x, array.elementType)).asJava - case (arr, array: ArrayType) if arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) }.asJava - case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) + + case (d: Decimal, _) => d.toJavaBigDecimal + case (s: UTF8String, StringType) => s.toString - // Pyrolite can handle Timestamp and Decimal case (other, _) => other } /** * Convert Row into Java Array (for pickled into Python) */ - def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { + def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = { // TODO: this is slow! row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray } - // Converts value to the type specified by the data type. - // Because Python does not have data types for TimestampType, FloatType, ShortType, and - // ByteType, we need to explicitly convert values in columns of these data types to the desired - // JVM data types. + /** + * Converts `obj` to the type specified by the data type, or returns null if the type of obj is + * unexpected. Because Python doesn't enforce the type. + */ def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - // TODO: We should check nullable case (null, _) => null + case (c: Boolean, BooleanType) => c + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + + case (c: Int, IntegerType) => c + case (c: Long, IntegerType) => c.toInt + + case (c: Int, LongType) => c.toLong + case (c: Long, LongType) => c + + case (c: Double, FloatType) => c.toFloat + + case (c: Double, DoubleType) => c + + case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c) + + case (c: Int, DateType) => c + + case (c: Long, TimestampType) => c + + case (c: String, StringType) => UTF8String.fromString(c) + case (c, StringType) => + // If we get here, c is not a string. Call toString on it. + UTF8String.fromString(c.toString) + + case (c: String, BinaryType) => c.getBytes("utf-8") + case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}: Seq[Any] + c.map { e => fromJava(e, elementType)}.toSeq case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any] + c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) @@ -188,30 +214,11 @@ object EvaluatePython { case (e, f) => fromJava(e, f.dataType) }) - case (c: java.util.Calendar, DateType) => - DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) - - case (c: java.util.Calendar, TimestampType) => - c.getTimeInMillis * 10000L - case (t: java.sql.Timestamp, TimestampType) => - DateTimeUtils.fromJavaTimestamp(t) - - case (_, udt: UserDefinedType[_]) => - fromJava(obj, udt.sqlType) - - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort - case (c: Long, IntegerType) => c.toInt - case (c: Int, LongType) => c.toLong - case (c: Double, FloatType) => c.toFloat - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) + case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) - case (c, _) => c + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + case (c, _) => null } } From 28fa01e2ba146e823489f6d81c5eb3a76b20c71f Mon Sep 17 00:00:00 2001 From: Jonathan Alter Date: Thu, 9 Jul 2015 03:28:51 +0100 Subject: [PATCH 203/274] [SPARK-8927] [DOCS] Format wrong for some config descriptions A couple descriptions were not inside `` and were being displayed immediately under the section title instead of in their row. Author: Jonathan Alter Closes #7292 from jonalter/docs-config and squashes the following commits: 5ce1570 [Jonathan Alter] [DOCS] Format wrong for some config descriptions --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index bebaf6f62e90a..892c02b27df32 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1007,9 +1007,9 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.numRetries 3 + Number of times to retry before an RPC task gives up. An RPC task will run at most times of this number. - @@ -1029,8 +1029,8 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out. From a290814877308c6fa9b0f78b1a81145db7651ca4 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 8 Jul 2015 20:20:17 -0700 Subject: [PATCH 204/274] [SPARK-8866][SQL] use 1us precision for timestamp type JIRA: https://issues.apache.org/jira/browse/SPARK-8866 Author: Yijie Shen Closes #7283 from yijieshen/micro_timestamp and squashes the following commits: dc735df [Yijie Shen] update CastSuite to avoid round error 714eaea [Yijie Shen] add timestamp_udf into blacklist due to precision lose c3ca2f4 [Yijie Shen] fix unhandled case in CurrentTimestamp 8d4aa6b [Yijie Shen] use 1us precision for timestamp type --- python/pyspark/sql/types.py | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 18 ++++----- .../expressions/datetimeFunctions.scala | 2 +- .../sql/catalyst/util/DateTimeUtils.scala | 38 +++++++++---------- .../sql/catalyst/expressions/CastSuite.scala | 10 ++--- .../catalyst/util/DateTimeUtilsSuite.scala | 8 ++-- .../apache/spark/sql/json/JacksonParser.scala | 4 +- .../org/apache/spark/sql/json/JsonRDD.scala | 6 +-- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 6 +-- .../spark/sql/hive/HiveInspectors.scala | 4 +- 11 files changed, 50 insertions(+), 50 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 7e64cb0b54dba..fecfe6d71e9a7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -775,7 +775,7 @@ def to_posix_timstamp(dt): if dt: seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple())) - return int(seconds * 1e7 + dt.microsecond * 10) + return int(seconds * 1e6 + dt.microsecond) return to_posix_timstamp else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 662ceeca7782d..567feca7136f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -186,7 +186,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 10000) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 1000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -207,16 +207,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } private[this] def decimalToTimestamp(d: Decimal): Long = { - (d.toBigDecimal * 10000000L).longValue() + (d.toBigDecimal * 1000000L).longValue() } - // converting milliseconds to 100ns - private[this] def longToTimestamp(t: Long): Long = t * 10000L - // converting 100ns to seconds - private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 10000000L).toLong - // converting 100ns to seconds in double + // converting milliseconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting us to seconds + private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong + // converting us to seconds in double private[this] def timestampToDouble(ts: Long): Double = { - ts / 10000000.0 + ts / 1000000.0 } // DateConverter @@ -229,7 +229,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 10000L)) + buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 1000L)) // Hive throws this exception as a Semantic Exception // It is never possible to compare result when hive return with exception, // so we can return null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index a492b966a5e31..dd5ec330a771b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -51,6 +51,6 @@ case class CurrentTimestamp() extends LeafExpression { override def dataType: DataType = TimestampType override def eval(input: InternalRow): Any = { - System.currentTimeMillis() * 10000L + System.currentTimeMillis() * 1000L } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 4269ad5d56737..c1ddee3ef0230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -34,8 +34,8 @@ object DateTimeUtils { // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 final val SECONDS_PER_DAY = 60 * 60 * 24L - final val HUNDRED_NANOS_PER_SECOND = 1000L * 1000L * 10L - final val NANOS_PER_SECOND = HUNDRED_NANOS_PER_SECOND * 100 + final val MICROS_PER_SECOND = 1000L * 1000L + final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. @@ -77,8 +77,8 @@ object DateTimeUtils { threadLocalDateFormat.get.format(toJavaDate(days)) // Converts Timestamp to string according to Hive TimestampWritable convention. - def timestampToString(num100ns: Long): String = { - val ts = toJavaTimestamp(num100ns) + def timestampToString(us: Long): String = { + val ts = toJavaTimestamp(us) val timestampString = ts.toString val formatted = threadLocalTimestampFormat.get.format(ts) @@ -132,52 +132,52 @@ object DateTimeUtils { } /** - * Returns a java.sql.Timestamp from number of 100ns since epoch. + * Returns a java.sql.Timestamp from number of micros since epoch. */ - def toJavaTimestamp(num100ns: Long): Timestamp = { + def toJavaTimestamp(us: Long): Timestamp = { // setNanos() will overwrite the millisecond part, so the milliseconds should be // cut off at seconds - var seconds = num100ns / HUNDRED_NANOS_PER_SECOND - var nanos = num100ns % HUNDRED_NANOS_PER_SECOND + var seconds = us / MICROS_PER_SECOND + var micros = us % MICROS_PER_SECOND // setNanos() can not accept negative value - if (nanos < 0) { - nanos += HUNDRED_NANOS_PER_SECOND + if (micros < 0) { + micros += MICROS_PER_SECOND seconds -= 1 } val t = new Timestamp(seconds * 1000) - t.setNanos(nanos.toInt * 100) + t.setNanos(micros.toInt * 1000) t } /** - * Returns the number of 100ns since epoch from java.sql.Timestamp. + * Returns the number of micros since epoch from java.sql.Timestamp. */ def fromJavaTimestamp(t: Timestamp): Long = { if (t != null) { - t.getTime() * 10000L + (t.getNanos().toLong / 100) % 10000L + t.getTime() * 1000L + (t.getNanos().toLong / 1000) % 1000L } else { 0L } } /** - * Returns the number of 100ns (hundred of nanoseconds) since epoch from Julian day + * Returns the number of microseconds since epoch from Julian day * and nanoseconds in a day */ def fromJulianDay(day: Int, nanoseconds: Long): Long = { // use Long to avoid rounding errors val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2 - seconds * HUNDRED_NANOS_PER_SECOND + nanoseconds / 100L + seconds * MICROS_PER_SECOND + nanoseconds / 1000L } /** - * Returns Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) + * Returns Julian day and nanoseconds in a day from the number of microseconds */ - def toJulianDay(num100ns: Long): (Int, Long) = { - val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2 + def toJulianDay(us: Long): (Int, Long) = { + val seconds = us / MICROS_PER_SECOND + SECONDS_PER_DAY / 2 val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH val secondsInDay = seconds % SECONDS_PER_DAY - val nanos = (num100ns % HUNDRED_NANOS_PER_SECOND) * 100L + val nanos = (us % MICROS_PER_SECOND) * 1000L (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 518961e38396f..919fdd470b79a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -293,15 +293,15 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from timestamp") { - val millis = 15 * 1000 + 2 - val seconds = millis * 1000 + 2 + val millis = 15 * 1000 + 3 + val seconds = millis * 1000 + 3 val ts = new Timestamp(millis) val tss = new Timestamp(seconds) checkEvaluation(cast(ts, ShortType), 15.toShort) checkEvaluation(cast(ts, IntegerType), 15) checkEvaluation(cast(ts, LongType), 15.toLong) - checkEvaluation(cast(ts, FloatType), 15.002f) - checkEvaluation(cast(ts, DoubleType), 15.002) + checkEvaluation(cast(ts, FloatType), 15.003f) + checkEvaluation(cast(ts, DoubleType), 15.003) checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) @@ -317,7 +317,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal(1)) // A test for higher precision than millis - checkEvaluation(cast(cast(0.0000001, TimestampType), DoubleType), 0.0000001) + checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) checkEvaluation(cast(Double.NaN, TimestampType), null) checkEvaluation(cast(1.0 / 0.0, TimestampType), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 1d4a60c81efc5..f63ac191e7366 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -24,11 +24,11 @@ import org.apache.spark.SparkFunSuite class DateTimeUtilsSuite extends SparkFunSuite { - test("timestamp and 100ns") { + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) - now.setNanos(100) + now.setNanos(1000) val ns = DateTimeUtils.fromJavaTimestamp(now) - assert(ns % 10000000L === 1) + assert(ns % 1000000L === 1) assert(DateTimeUtils.toJavaTimestamp(ns) === now) List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => @@ -38,7 +38,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { } } - test("100ns and julian day") { + test("us and julian day") { val (d, ns) = DateTimeUtils.toJulianDay(0) assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 4b8ab63b5ab39..381e7ed54428f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -67,10 +67,10 @@ private[sql] object JacksonParser { DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) case (VALUE_STRING, TimestampType) => - DateTimeUtils.stringToTime(parser.getText).getTime * 10000L + DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => - parser.getLongValue * 10000L + parser.getLongValue * 1000L case (_, StringType) => val writer = new ByteArrayOutputStream() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 01ba05cbd14f1..b392a51bf7dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -401,9 +401,9 @@ private[sql] object JsonRDD extends Logging { private def toTimestamp(value: Any): Long = { value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L - case value: java.lang.Long => value * 10000L - case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 10000L + case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 1000L + case value: java.lang.Long => value * 1000L + case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 1000L } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 69ab1c292d221..566a52dc1b784 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -326,7 +326,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(cal.get(Calendar.HOUR) === 11) assert(cal.get(Calendar.MINUTE) === 22) assert(cal.get(Calendar.SECOND) === 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543000) } test("test DATE types") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 415a81644c58f..c884c399281a8 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -254,9 +254,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // the answer is sensitive for jdk version "udf_java_method", - // Spark SQL use Long for TimestampType, lose the precision under 100ns + // Spark SQL use Long for TimestampType, lose the precision under 1us "timestamp_1", - "timestamp_2" + "timestamp_2", + "timestamp_udf" ) /** @@ -803,7 +804,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "timestamp_udf", "touch", "transform_ppr1", "transform_ppr2", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4cba17524af6c..a8f2ee37cb8ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -267,7 +267,7 @@ private[hive] trait HiveInspectors { poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => val t = poi.getWritableConstantValue - t.getSeconds * 10000000L + t.getNanos / 100L + t.getSeconds * 1000000L + t.getNanos / 1000L case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantDoubleObjectInspector => @@ -332,7 +332,7 @@ private[hive] trait HiveInspectors { case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) case x: TimestampObjectInspector if x.preferWritable() => val t = x.getPrimitiveWritableObject(data) - t.getSeconds * 10000000L + t.getNanos / 100 + t.getSeconds * 1000000L + t.getNanos / 1000L case ti: TimestampObjectInspector => DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) From b55499a44ab74e33378211fb0d6940905d7c6318 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 8 Jul 2015 20:28:05 -0700 Subject: [PATCH 205/274] [SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools We call Row.copy() in many places throughout SQL but UnsafeRow currently throws UnsupportedOperationException when copy() is called. Supporting copying when ObjectPool is used may be difficult, since we may need to handle deep-copying of objects in the pool. In addition, this copy() method needs to produce a self-contained row object which may be passed around / buffered by downstream code which does not understand the UnsafeRow format. In the long run, we'll need to figure out how to handle the ObjectPool corner cases, but this may be unnecessary if other changes are made. Therefore, in order to unblock my sort patch (#6444) I propose that we support copy() for the cases where UnsafeRow does not use an ObjectPool and continue to throw UnsupportedOperationException when an ObjectPool is used. This patch accomplishes this by modifying UnsafeRow so that it knows the size of the row's backing data in order to be able to copy it into a byte array. Author: Josh Rosen Closes #7306 from JoshRosen/SPARK-8932 and squashes the following commits: 338e6bf [Josh Rosen] Support copy for UnsafeRows that do not use ObjectPools. --- .../UnsafeFixedWidthAggregationMap.java | 12 +++-- .../sql/catalyst/expressions/UnsafeRow.java | 32 +++++++++++- .../expressions/UnsafeRowConverter.scala | 10 +++- .../expressions/UnsafeRowConverterSuite.scala | 52 ++++++++++++++----- 4 files changed, 87 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 1e79f4b2e88e5..79d55b36dab01 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -120,9 +120,11 @@ public UnsafeFixedWidthAggregationMap( this.bufferPool = new ObjectPool(initialCapacity); InternalRow initRow = initProjection.apply(emptyRow); - this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int emptyBufferSize = bufferConverter.getSizeRequirement(initRow); + this.emptyBuffer = new byte[emptyBufferSize]; int writtenLength = bufferConverter.writeRow( - initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize, + bufferPool); assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; // re-use the empty buffer only when there is no object saved in pool. reuseEmptyBuffer = bufferPool.size() == 0; @@ -142,6 +144,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { groupingKey, groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, + groupingKeySize, keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; @@ -157,7 +160,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { // There is some objects referenced by emptyBuffer, so generate a new one InternalRow initRow = initProjection.apply(emptyRow); bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - bufferPool); + groupingKeySize, bufferPool); } loc.putNewKey( groupingKeyConversionScratchSpace, @@ -175,6 +178,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { address.getBaseObject(), address.getBaseOffset(), bufferConverter.numFields(), + loc.getValueLength(), bufferPool ); return currentBuffer; @@ -214,12 +218,14 @@ public MapEntry next() { keyAddress.getBaseObject(), keyAddress.getBaseOffset(), keyConverter.numFields(), + loc.getKeyLength(), keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), bufferConverter.numFields(), + loc.getValueLength(), bufferPool ); return entry; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index aeb64b045812f..edb7202245289 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow { /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; + /** The size of this row's backing data, in bytes) */ + private int sizeInBytes; + public int length() { return numFields; } /** The width of the null tracking bit set, in bytes */ @@ -95,14 +98,17 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row + * @param sizeInBytes the size of this row's backing data, in bytes * @param pool the object pool to hold arbitrary objects */ - public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { + public void pointTo( + Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; + this.sizeInBytes = sizeInBytes; this.pool = pool; } @@ -336,9 +342,31 @@ public double getDouble(int i) { } } + /** + * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal + * byte array rather than referencing data stored in a data page. + *

+ * This method is only supported on UnsafeRows that do not use ObjectPools. + */ @Override public InternalRow copy() { - throw new UnsupportedOperationException(); + if (pool != null) { + throw new UnsupportedOperationException( + "Copy is not supported for UnsafeRows that use object pools"); + } else { + UnsafeRow rowCopy = new UnsafeRow(); + final byte[] rowDataCopy = new byte[sizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeInBytes + ); + rowCopy.pointTo( + rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null); + return rowCopy; + } } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 1f395497a9839..6af5e6200e57b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param row the row to convert * @param baseObject the base object of the destination address * @param baseOffset the base offset of the destination address + * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)` * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) + def writeRow( + row: InternalRow, + baseObject: Object, + baseOffset: Long, + rowLengthInBytes: Int, + pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool) if (writers.length > 0) { // zero-out the bitset diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 96d4e64ea344a..d00aeb4dfbf47 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) + // We can copy UnsafeRows as long as they don't reference ObjectPools + val unsafeRowCopy = unsafeRow.copy() + assert(unsafeRowCopy.getLong(0) === 0) + assert(unsafeRowCopy.getLong(1) === 1) + assert(unsafeRowCopy.getInt(2) === 2) + unsafeRow.setLong(1, 3) assert(unsafeRow.getLong(1) === 3) unsafeRow.setInt(2, 4) assert(unsafeRow.getInt(2) === 4) + + // Mutating the original row should not have changed the copy + assert(unsafeRowCopy.getLong(0) === 0) + assert(unsafeRowCopy.getLong(1) === 1) + assert(unsafeRowCopy.getInt(2) === 2) } test("basic conversion with primitive, string and binary types") { @@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = converter.writeRow( + row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() val pool = new ObjectPool(10) - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") assert(unsafeRow.get(2) === "World".getBytes) @@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.update(2, "Hello World".getBytes) assert(unsafeRow.get(2) === "Hello World".getBytes) assert(pool.size === 2) + + // We do not support copy() for UnsafeRows that reference ObjectPools + intercept[UnsupportedOperationException] { + unsafeRow.copy() + } } test("basic conversion with primitive, decimal and array") { @@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) assert(sizeRequired === 8 + (8 * 3)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool) assert(numBytesWritten === sizeRequired) assert(pool.size === 2) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.get(1) === Decimal(1)) assert(unsafeRow.get(2) === Array(2)) @@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(sizeRequired === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow @@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, + sizeRequired, null) assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( - createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, + sizeRequired, null) for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } @@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val pool = new ObjectPool(1) val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, + sizeRequired, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, + sizeRequired, pool) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) From 47ef423f860c3109d50c7e321616b267f4296e34 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 8 Jul 2015 20:29:08 -0700 Subject: [PATCH 206/274] [SPARK-8910] Fix MiMa flaky due to port contention issue Due to the way MiMa works, we currently start a `SQLContext` pretty early on. This causes us to start a `SparkUI` that attempts to bind to port 4040. Because many tests run in parallel on the Jenkins machines, this causes port contention sometimes and fails the MiMa tests. Note that we already disabled the SparkUI for scalatests. However, the MiMa test is run before we even have a chance to load the default scalatest settings, so we need to explicitly disable the UI ourselves. Author: Andrew Or Closes #7300 from andrewor14/mima-flaky and squashes the following commits: b55a547 [Andrew Or] Do not enable SparkUI during tests --- .../scala/org/apache/spark/sql/test/TestSQLContext.scala | 8 ++++---- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 9fa394525d65c..b3a4231da91c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. */ class LocalSQLContext extends SQLContext( - new SparkContext( - "local[2]", - "TestSQLContext", - new SparkConf().set("spark.sql.testkey", "true"))) { + new SparkContext("local[2]", "TestSQLContext", new SparkConf() + .set("spark.sql.testkey", "true") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) { override protected[sql] def createSession(): SQLSession = { new this.SQLSession() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7978fdacaedba..0f217bc66869f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -53,9 +53,10 @@ object TestHive "TestSQLContext", new SparkConf() .set("spark.sql.test", "") - .set( - "spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe"))) + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) /** * A locally running test instance of Spark's Hive execution engine. From aba5784dab24c03ddad89f7a1b5d3d0dc8d109be Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 9 Jul 2015 13:28:17 +0900 Subject: [PATCH 207/274] [SPARK-8937] [TEST] A setting `spark.unsafe.exceptionOnMemoryLeak ` is missing in ScalaTest config. `spark.unsafe.exceptionOnMemoryLeak` is present in the config of surefire. ``` org.apache.maven.plugins maven-surefire-plugin 2.18.1 ... true ... ``` but is absent in the config ScalaTest. Author: Kousuke Saruta Closes #7308 from sarutak/add-setting-for-memory-leak and squashes the following commits: 95644e7 [Kousuke Saruta] Added a setting for memory leak --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index 9cf2471b51304..529e47f8b5253 100644 --- a/pom.xml +++ b/pom.xml @@ -1339,6 +1339,7 @@ false false true + true From 768907eb7b0d3c11a420ef281454e36167011c89 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 8 Jul 2015 22:05:58 -0700 Subject: [PATCH 208/274] [SPARK-8926][SQL] Good errors for ExpectsInputType expressions For example: `cannot resolve 'testfunction(null)' due to data type mismatch: argument 1 is expected to be of type int, however, null is of type datetype.` Author: Michael Armbrust Closes #7303 from marmbrus/expectsTypeErrors and squashes the following commits: c654a0e [Michael Armbrust] fix udts and make errors pretty 137160d [Michael Armbrust] style 5428fda [Michael Armbrust] style 10fac82 [Michael Armbrust] [SPARK-8926][SQL] Good errors for ExpectsInputType expressions --- .../catalyst/analysis/HiveTypeCoercion.scala | 12 +- .../expressions/ExpectsInputTypes.scala | 13 +- .../spark/sql/types/AbstractDataType.scala | 30 +++- .../apache/spark/sql/types/ArrayType.scala | 8 +- .../org/apache/spark/sql/types/DataType.scala | 4 +- .../apache/spark/sql/types/DecimalType.scala | 8 +- .../org/apache/spark/sql/types/MapType.scala | 8 +- .../apache/spark/sql/types/StructType.scala | 8 +- .../spark/sql/types/UserDefinedType.scala | 5 +- .../analysis/AnalysisErrorSuite.scala | 167 ++++++++++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 126 ++----------- .../analysis/HiveTypeCoercionSuite.scala | 8 + .../apache/spark/sql/hive/HiveContext.scala | 2 +- 13 files changed, 256 insertions(+), 143 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 5367b7f3308ee..8cb71995eb818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -702,11 +702,19 @@ object HiveTypeCoercion { @Nullable val ret: Expression = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. - case _ if expectedType.isParentOf(inType) => e + case _ if expectedType.isSameType(inType) => e // Cast null type (usually from null literals) into target types case (NullType, target) => Cast(e, target.defaultConcreteType) + // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is + // already a number, leave it as is. + case (_: NumericType, NumericType) => e + + // If the function accepts any numeric type and the input is a string, we follow the hive + // convention and cast that input into a double + case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) + // Implicit cast among numeric types // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to unlimited precision decimal. @@ -732,7 +740,7 @@ object HiveTypeCoercion { // First see if we can find our input type in the type collection. If we can, then just // use the current expression; otherwise, find the first one we can implicitly cast. case (_, TypeCollection(types)) => - if (types.exists(_.isParentOf(inType))) { + if (types.exists(_.isSameType(inType))) { e } else { types.flatMap(implicitCast(e, _)).headOption.orNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 916e30154d4f1..986cc09499d1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -37,7 +37,16 @@ trait ExpectsInputTypes { self: Expression => def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - // TODO: implement proper type checking. - TypeCheckResult.TypeCheckSuccess + val mismatches = children.zip(inputTypes).zipWithIndex.collect { + case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + + s"however, ${child.prettyString} is of type ${child.dataType.simpleString}." + } + + if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index fb1b47e946214..ad75fa2e31d90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -34,9 +34,16 @@ private[sql] abstract class AbstractDataType { private[sql] def defaultConcreteType: DataType /** - * Returns true if this data type is a parent of the `childCandidate`. + * Returns true if this data type is the same type as `other`. This is different that equality + * as equality will also consider data type parametrization, such as decimal precision. */ - private[sql] def isParentOf(childCandidate: DataType): Boolean + private[sql] def isSameType(other: DataType): Boolean + + /** + * Returns true if `other` is an acceptable input type for a function that expectes this, + * possibly abstract, DataType. + */ + private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) /** Readable string representation for the type. */ private[sql] def simpleString: String @@ -58,11 +65,14 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) require(types.nonEmpty, s"TypeCollection ($types) cannot be empty") - private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType + override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType + + override private[sql] def isSameType(other: DataType): Boolean = false - private[sql] override def isParentOf(childCandidate: DataType): Boolean = false + override private[sql] def acceptsType(other: DataType): Boolean = + types.exists(_.isSameType(other)) - private[sql] override def simpleString: String = { + override private[sql] def simpleString: String = { types.map(_.simpleString).mkString("(", " or ", ")") } } @@ -108,7 +118,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType { +private[sql] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -117,6 +127,14 @@ private[sql] object NumericType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] + + override private[sql] def defaultConcreteType: DataType = DoubleType + + override private[sql] def simpleString: String = "numeric" + + override private[sql] def isSameType(other: DataType): Boolean = false + + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 43413ec761e6b..76ca7a84c1d1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -26,13 +26,13 @@ object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) - private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) + override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[ArrayType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[ArrayType] } - private[sql] override def simpleString: String = "array" + override private[sql] def simpleString: String = "array" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a4c2da8e05f5d..57718228e490f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -76,9 +76,9 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType - private[sql] override def defaultConcreteType: DataType = this + override private[sql] def defaultConcreteType: DataType = this - private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate + override private[sql] def isSameType(other: DataType): Boolean = this == other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 127b16ff85bed..a1cafeab1704d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -84,13 +84,13 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ object DecimalType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = Unlimited + override private[sql] def defaultConcreteType: DataType = Unlimited - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[DecimalType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[DecimalType] } - private[sql] override def simpleString: String = "decimal" + override private[sql] def simpleString: String = "decimal" val Unlimited: DecimalType = DecimalType(None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 868dea13d971e..ddead10bc2171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -69,13 +69,13 @@ case class MapType( object MapType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType) + override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[MapType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[MapType] } - private[sql] override def simpleString: String = "map" + override private[sql] def simpleString: String = "map" /** * Construct a [[MapType]] object with the given key type and value type. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e2d3f53f7d978..e0b8ff91786a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -303,13 +303,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = new StructType + override private[sql] def defaultConcreteType: DataType = new StructType - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[StructType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[StructType] } - private[sql] override def simpleString: String = "struct" + override private[sql] def simpleString: String = "struct" private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match { case t: StructType => t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 6b20505c6009a..e47cfb4833bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -77,5 +77,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * For UDT, asNullable will not change the nullability of its internal sqlType and just returns * itself. */ - private[spark] override def asNullable: UserDefinedType[UserType] = this + override private[spark] def asNullable: UserDefinedType[UserType] = this + + override private[sql] def acceptsType(dataType: DataType) = + this.getClass == dataType.getClass } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala new file mode 100644 index 0000000000000..73236c3acbca2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.{InternalRow, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +case class TestFunction( + children: Seq[Expression], + inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def dataType: DataType = StringType +} + +case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output: Seq[Attribute] = Nil +} + +class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { + import AnalysisSuite._ + + def errorTest( + name: String, + plan: LogicalPlan, + errorMessages: Seq[String], + caseSensitive: Boolean = true): Unit = { + test(name) { + val error = intercept[AnalysisException] { + if (caseSensitive) { + caseSensitiveAnalyze(plan) + } else { + caseInsensitiveAnalyze(plan) + } + } + + errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) + } + } + + val dateLit = Literal.create(null, DateType) + + errorTest( + "single invalid type, single arg", + testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" :: + "null is of type date" ::Nil) + + errorTest( + "single invalid type, second arg", + testRelation.select( + TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" :: + "null is of type date" ::Nil) + + errorTest( + "multiple invalid type", + testRelation.select( + TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: + "expected to be of type int" :: "null is of type date" ::Nil) + + errorTest( + "unresolved window function", + testRelation2.select( + WindowExpression( + UnresolvedWindowFunction( + "lead", + UnresolvedAttribute("c") :: Nil), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "lead" :: "window functions currently requires a HiveContext" :: Nil) + + errorTest( + "too many generators", + listRelation.select(Explode('list).as('a), Explode('list).as('b)), + "only one generator" :: "explode" :: Nil) + + errorTest( + "unresolved attributes", + testRelation.select('abcd), + "cannot resolve" :: "abcd" :: Nil) + + errorTest( + "bad casts", + testRelation.select(Literal(1).cast(BinaryType).as('badCast)), + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + + errorTest( + "non-boolean filters", + testRelation.where(Literal(1)), + "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + + errorTest( + "missing group by", + testRelation2.groupBy('a)('b), + "'b'" :: "group by" :: Nil + ) + + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + + errorTest( + "catch all unresolved plan", + UnresolvedTestPlan(), + "unresolved" :: Nil) + + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + LocalRelation( + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 77ca080f366cd..58df1de983a09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { +object AnalysisSuite { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -61,25 +61,28 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil ))()) val nestedRelation2 = LocalRelation( AttributeReference("top", StructType( StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil ))()) val listRelation = LocalRelation( AttributeReference("list", ArrayType(IntegerType))()) - before { - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - } + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) +} + + +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { + import AnalysisSuite._ test("union project *") { val plan = (1 to 100) @@ -149,91 +152,6 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } - def errorTest( - name: String, - plan: LogicalPlan, - errorMessages: Seq[String], - caseSensitive: Boolean = true): Unit = { - test(name) { - val error = intercept[AnalysisException] { - if (caseSensitive) { - caseSensitiveAnalyze(plan) - } else { - caseInsensitiveAnalyze(plan) - } - } - - errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) - } - } - - errorTest( - "unresolved window function", - testRelation2.select( - WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), - WindowSpecDefinition( - UnresolvedAttribute("a") :: Nil, - SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, - UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) - - errorTest( - "too many generators", - listRelation.select(Explode('list).as('a), Explode('list).as('b)), - "only one generator" :: "explode" :: Nil) - - errorTest( - "unresolved attributes", - testRelation.select('abcd), - "cannot resolve" :: "abcd" :: Nil) - - errorTest( - "bad casts", - testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) - - errorTest( - "non-boolean filters", - testRelation.where(Literal(1)), - "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) - - errorTest( - "missing group by", - testRelation2.groupBy('a)('b), - "'b'" :: "group by" :: Nil - ) - - errorTest( - "ambiguous field", - nestedRelation.select($"top.duplicateField"), - "Ambiguous reference to fields" :: "duplicateField" :: Nil, - caseSensitive = false) - - errorTest( - "ambiguous field due to case insensitivity", - nestedRelation.select($"top.differentCase"), - "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, - caseSensitive = false) - - errorTest( - "missing field", - nestedRelation2.select($"top.c"), - "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, - caseSensitive = false) - - case class UnresolvedTestPlan() extends LeafNode { - override lazy val resolved = false - override def output: Seq[Attribute] = Nil - } - - errorTest( - "catch all unresolved plan", - UnresolvedTestPlan(), - "unresolved" :: Nil) - test("divide should be casted into fractional types") { val testRelation2 = LocalRelation( @@ -258,22 +176,4 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } - - test("SPARK-6452 regression test") { - // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) - val plan = - Aggregate( - Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, - LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) - - assert(plan.resolved) - - val message = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - }.getMessage - - assert(message.contains("resolved attribute(s) a#1 missing from a#2")) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 93db33d44eb25..6e3aa0eebeb15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -77,6 +77,14 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) + + shouldCast(StringType, NumericType, DoubleType) + + // NumericType should not be changed when function accepts any of them. + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => + shouldCast(tpe, NumericType, tpe) + } } test("ineligible implicit type cast") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 439d8cab5f257..bbc39b892b79e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -359,7 +359,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { hiveconf.set(key, value) } - private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { setConf(entry.key, entry.stringConverter(value)) } From a240bf3b44b15d0da5182d6ebec281dbdc5439e8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Jul 2015 22:08:50 -0700 Subject: [PATCH 209/274] Closes #7310. From 3dab0da42940a46f0c4aa4853bdb5c64c4cb2613 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Jul 2015 22:09:12 -0700 Subject: [PATCH 210/274] [SPARK-8928] [SQL] Makes CatalystSchemaConverter sticking to 1.4.x- when handling Parquet LISTs in compatible mode This PR is based on #7209 authored by Sephiroth-Lin. Author: Weizhong Lin Closes #7304 from liancheng/spark-8928 and squashes the following commits: 75267fe [Cheng Lian] Makes CatalystSchemaConverter sticking to 1.4.x- when handling LISTs in compatible mode --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 6 ++++-- .../apache/spark/sql/parquet/ParquetSchemaSuite.scala | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index de3a72d8146c5..1ea6926af6d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -461,7 +461,8 @@ private[parquet] class CatalystSchemaConverter( field.name, Types .buildGroup(REPEATED) - .addField(convertField(StructField("element", elementType, nullable))) + // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) + .addField(convertField(StructField("array_element", elementType, nullable))) .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level @@ -474,7 +475,8 @@ private[parquet] class CatalystSchemaConverter( ConversionPatterns.listType( repetition, field.name, - convertField(StructField("element", elementType, nullable), REPEATED)) + // "array" is the name chosen by parquet-avro (1.7.0 and prior version) + convertField(StructField("array", elementType, nullable), REPEATED)) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 35d3c33f99a06..fa629392674bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -174,7 +174,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """ |message root { | optional group _1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) @@ -198,7 +198,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -267,7 +267,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group element { + | optional group array_element { | required int32 _1; | required double _2; | } @@ -616,7 +616,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -648,7 +648,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { nullable = true))), """message root { | optional group f1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) From c056484c0741e2a03d4a916538e1b9e3e65e71c3 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Jul 2015 22:14:38 -0700 Subject: [PATCH 211/274] Revert "[SPARK-8928] [SQL] Makes CatalystSchemaConverter sticking to 1.4.x- when handling Parquet LISTs in compatible mode" This reverts commit 3dab0da42940a46f0c4aa4853bdb5c64c4cb2613. --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 6 ++---- .../apache/spark/sql/parquet/ParquetSchemaSuite.scala | 10 +++++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 1ea6926af6d5b..de3a72d8146c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -461,8 +461,7 @@ private[parquet] class CatalystSchemaConverter( field.name, Types .buildGroup(REPEATED) - // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) - .addField(convertField(StructField("array_element", elementType, nullable))) + .addField(convertField(StructField("element", elementType, nullable))) .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level @@ -475,8 +474,7 @@ private[parquet] class CatalystSchemaConverter( ConversionPatterns.listType( repetition, field.name, - // "array" is the name chosen by parquet-avro (1.7.0 and prior version) - convertField(StructField("array", elementType, nullable), REPEATED)) + convertField(StructField("element", elementType, nullable), REPEATED)) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index fa629392674bd..35d3c33f99a06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -174,7 +174,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """ |message root { | optional group _1 (LIST) { - | repeated int32 array; + | repeated int32 element; | } |} """.stripMargin) @@ -198,7 +198,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 element; | } | } |} @@ -267,7 +267,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array_element { + | optional group element { | required int32 _1; | required double _2; | } @@ -616,7 +616,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 element; | } | } |} @@ -648,7 +648,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { nullable = true))), """message root { | optional group f1 (LIST) { - | repeated int32 array; + | repeated int32 element; | } |} """.stripMargin) From 851e247caad0977cfd4998254d9602624e06539f Mon Sep 17 00:00:00 2001 From: Weizhong Lin Date: Wed, 8 Jul 2015 22:18:39 -0700 Subject: [PATCH 212/274] [SPARK-8928] [SQL] Makes CatalystSchemaConverter sticking to 1.4.x- when handling Parquet LISTs in compatible mode This PR is based on #7209 authored by Sephiroth-Lin. Author: Weizhong Lin Closes #7314 from liancheng/spark-8928 and squashes the following commits: 75267fe [Cheng Lian] Makes CatalystSchemaConverter sticking to 1.4.x- when handling LISTs in compatible mode --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 6 ++++-- .../apache/spark/sql/parquet/ParquetSchemaSuite.scala | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index de3a72d8146c5..1ea6926af6d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -461,7 +461,8 @@ private[parquet] class CatalystSchemaConverter( field.name, Types .buildGroup(REPEATED) - .addField(convertField(StructField("element", elementType, nullable))) + // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) + .addField(convertField(StructField("array_element", elementType, nullable))) .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level @@ -474,7 +475,8 @@ private[parquet] class CatalystSchemaConverter( ConversionPatterns.listType( repetition, field.name, - convertField(StructField("element", elementType, nullable), REPEATED)) + // "array" is the name chosen by parquet-avro (1.7.0 and prior version) + convertField(StructField("array", elementType, nullable), REPEATED)) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 35d3c33f99a06..fa629392674bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -174,7 +174,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """ |message root { | optional group _1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) @@ -198,7 +198,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -267,7 +267,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group element { + | optional group array_element { | required int32 _1; | required double _2; | } @@ -616,7 +616,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -648,7 +648,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { nullable = true))), """message root { | optional group f1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) From 09cb0d9c2dcb83818ced22ff9bd6a51688ea7ffe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Jul 2015 00:26:25 -0700 Subject: [PATCH 213/274] [SPARK-8942][SQL] use double not decimal when cast double and float to timestamp Author: Wenchen Fan Closes #7312 from cloud-fan/minor and squashes the following commits: a4589fa [Wenchen Fan] use double not decimal when cast double and float to timestamp --- .../spark/sql/catalyst/expressions/Cast.scala | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 567feca7136f9..7f2383dedc035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -192,23 +192,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Decimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp case DoubleType => - buildCast[Double](_, d => try { - decimalToTimestamp(Decimal(d)) - } catch { - case _: NumberFormatException => null - }) + buildCast[Double](_, d => doubleToTimestamp(d)) // TimestampWritable.floatToTimestamp case FloatType => - buildCast[Float](_, f => try { - decimalToTimestamp(Decimal(f)) - } catch { - case _: NumberFormatException => null - }) + buildCast[Float](_, f => doubleToTimestamp(f.toDouble)) } private[this] def decimalToTimestamp(d: Decimal): Long = { (d.toBigDecimal * 1000000L).longValue() } + private[this] def doubleToTimestamp(d: Double): Any = { + if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong + } // converting milliseconds to us private[this] def longToTimestamp(t: Long): Long = t * 1000L @@ -396,8 +391,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[InternalRow](_, row => { var i = 0 while (i < row.length) { - val v = row(i) - newRow.update(i, if (v == null) null else casts(i)(v)) + newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row(i))) i += 1 } newRow.copy() From f88b12537ee81d914ef7c51a08f80cb28d93c8ed Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 9 Jul 2015 08:16:26 -0700 Subject: [PATCH 214/274] [SPARK-6266] [MLLIB] PySpark SparseVector missing doc for size, indices, values Write missing pydocs in `SparseVector` attributes. Author: lewuathe Closes #7290 from Lewuathe/SPARK-6266 and squashes the following commits: 51d9895 [lewuathe] Update docs 0480d35 [lewuathe] Merge branch 'master' into SPARK-6266 ba42cf3 [lewuathe] [SPARK-6266] PySpark SparseVector missing doc for size, indices, values --- python/pyspark/mllib/linalg.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51ac198305711..040886f71775b 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -445,8 +445,10 @@ def __init__(self, size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, - or two sorted lists containing indices and values. + :param args: Active entries, as a dictionary {index: value, ...}, + a list of tuples [(index, value), ...], or a list of strictly i + ncreasing indices and a list of corresponding values [index, ...], + [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) SparseVector(4, {1: 1.0, 3: 5.5}) @@ -456,6 +458,7 @@ def __init__(self, size, *args): SparseVector(4, {1: 1.0, 3: 5.5}) """ self.size = int(size) + """ Size of the vector. """ assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" if len(args) == 1: pairs = args[0] @@ -463,7 +466,9 @@ def __init__(self, size, *args): pairs = pairs.items() pairs = sorted(pairs) self.indices = np.array([p[0] for p in pairs], dtype=np.int32) + """ A list of indices corresponding to active entries. """ self.values = np.array([p[1] for p in pairs], dtype=np.float64) + """ A list of values corresponding to active entries. """ else: if isinstance(args[0], bytes): assert isinstance(args[1], bytes), "values should be string too" From 23448a9e988a1b92bd05ee8c6c1a096c83375a12 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 09:20:16 -0700 Subject: [PATCH 215/274] [SPARK-8931] [SQL] Fallback to interpreted evaluation if failed to compile in codegen Exception will not be catched during tests. cc marmbrus rxin Author: Davies Liu Closes #7309 from davies/fallback and squashes the following commits: 969a612 [Davies Liu] throw exception during tests f844f77 [Davies Liu] fallback a3091bc [Davies Liu] Merge branch 'master' of github.com:apache/spark into fallback 364a0d6 [Davies Liu] fallback to interpret mode if failed to compile --- .../spark/sql/execution/SparkPlan.scala | 51 +++++++++++++++++-- .../apache/spark/sql/sources/commands.scala | 13 ++++- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ca53186383237..4d7d8626a0ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -153,12 +153,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ buf.toArray.map(converter(_).asInstanceOf[Row]) } + private[this] def isTesting: Boolean = sys.props.contains("spark.testing") + protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection.generate(expressions, inputSchema) + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } + } } else { new InterpretedProjection(expressions, inputSchema) } @@ -170,17 +182,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled) { - GenerateMutableProjection.generate(expressions, inputSchema) + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } } else { () => new InterpretedMutableProjection(expressions, inputSchema) } } - protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { if (codegenEnabled) { - GeneratePredicate.generate(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } + } } else { InterpretedPredicate.create(expression, inputSchema) } @@ -190,7 +221,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { if (codegenEnabled) { - GenerateOrdering.generate(order, inputSchema) + try { + GenerateOrdering.generate(order, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate ordering, fallback to interpreted", e) + new RowOrdering(order, inputSchema) + } + } } else { new RowOrdering(order, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index ecbc889770625..9189d176111d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -276,7 +276,18 @@ private[sql] case class InsertIntoHadoopFsRelation( log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection.generate(expressions, inputSchema) + + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (sys.props.contains("spark.testing")) { + throw e + } else { + log.error("failed to generate projection, fallback to interpreted", e) + new InterpretedProjection(expressions, inputSchema) + } + } } else { new InterpretedProjection(expressions, inputSchema) } From a1964e9d902bb31f001893da8bc81f6dce08c908 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 9 Jul 2015 09:22:24 -0700 Subject: [PATCH 216/274] [SPARK-8830] [SQL] native levenshtein distance Jira: https://issues.apache.org/jira/browse/SPARK-8830 rxin and HuJiayin can you have a look on it. Author: Tarek Auel Closes #7236 from tarekauel/native-levenshtein-distance and squashes the following commits: ee4c4de [Tarek Auel] [SPARK-8830] implemented improvement proposals c252e71 [Tarek Auel] [SPARK-8830] removed chartAt; use unsafe method for byte array comparison ddf2222 [Tarek Auel] Merge branch 'master' into native-levenshtein-distance 179920a [Tarek Auel] [SPARK-8830] added description 5e9ed54 [Tarek Auel] [SPARK-8830] removed StringUtils import dce4308 [Tarek Auel] [SPARK-8830] native levenshtein distance --- .../expressions/stringOperations.scala | 9 ++- .../expressions/StringFunctionsSuite.scala | 5 ++ .../apache/spark/unsafe/types/UTF8String.java | 66 ++++++++++++++++++- .../spark/unsafe/types/UTF8StringSuite.java | 24 +++++++ 4 files changed, 97 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 47fc7cdaa826c..57f436485becf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -284,13 +284,12 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = IntegerType - protected override def nullSafeEval(input1: Any, input2: Any): Any = - StringUtils.getLevenshteinDistance(input1.toString, input2.toString) + protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = + leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val stringUtils = classOf[StringUtils].getName - defineCodeGen(ctx, ev, (left, right) => - s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())") + nullSafeCodeGen(ctx, ev, (left, right) => + s"${ev.primitive} = $left.levenshteinDistance($right);") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 1efbe1a245e83..69bef1c63e9dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -282,5 +282,10 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0) checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3) checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3) + checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) + // scalastyle:on } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d2a25096a5e7a..847d80ad583f6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -99,8 +99,6 @@ public int numBytes() { /** * Returns the number of code points in it. - * - * This is only used by Substring() when `start` is negative. */ public int numChars() { int len = 0; @@ -254,6 +252,70 @@ public boolean equals(final Object other) { } } + /** + * Levenshtein distance is a metric for measuring the distance of two strings. The distance is + * defined by the minimum number of single-character edits (i.e. insertions, deletions or + * substitutions) that are required to change one of the strings into the other. + */ + public int levenshteinDistance(UTF8String other) { + // Implementation adopted from org.apache.common.lang3.StringUtils.getLevenshteinDistance + + int n = numChars(); + int m = other.numChars(); + + if (n == 0) { + return m; + } else if (m == 0) { + return n; + } + + UTF8String s, t; + + if (n <= m) { + s = this; + t = other; + } else { + s = other; + t = this; + int swap; + swap = n; + n = m; + m = swap; + } + + int p[] = new int[n + 1]; + int d[] = new int[n + 1]; + int swap[]; + + int i, i_bytes, j, j_bytes, num_bytes_j, cost; + + for (i = 0; i <= n; i++) { + p[i] = i; + } + + for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) { + num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes)); + d[0] = j + 1; + + for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) { + if (s.getByte(i_bytes) != t.getByte(j_bytes) || + num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { + cost = 1; + } else { + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + } + d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); + } + + swap = p; + p = d; + d = swap; + } + + return p[n]; + } + @Override public int hashCode() { int result = 1; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 8ec69ebac8b37..fb463ba17f50b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -128,4 +128,28 @@ public void substring() { assertEquals(fromString("数据砖头").substring(3, 5), fromString("头")); assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷")); } + + @Test + public void levenshteinDistance() { + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")), 0); + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")), 1); + assertEquals( + UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")), 7); + assertEquals( + UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")), 1); + assertEquals( + UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")),3); + assertEquals( + UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")), 8); + assertEquals( + UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")),1); + assertEquals( + UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4); + } } From 59cc38944fe5c1dffc6551775bd939e2ac66c65e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Jul 2015 09:57:12 -0700 Subject: [PATCH 217/274] [SPARK-8940] [SPARKR] Don't overwrite given schema in createDataFrame JIRA: https://issues.apache.org/jira/browse/SPARK-8940 The given `schema` parameter will be overwritten in `createDataFrame` now. If it is not null, we shouldn't overwrite it. Author: Liang-Chi Hsieh Closes #7311 from viirya/df_not_overwrite_schema and squashes the following commits: 2385139 [Liang-Chi Hsieh] Don't overwrite given schema if it is not null. --- R/pkg/R/SQLContext.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9a743a3411533..30978bb50d339 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -86,7 +86,9 @@ infer_type <- function(x) { createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD - schema <- names(data) + if (is.null(schema)) { + schema <- names(data) + } n <- nrow(data) m <- ncol(data) # get rid of factor type From e204d22bb70f28b1cc090ab60f12078479be4ae0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 9 Jul 2015 10:01:01 -0700 Subject: [PATCH 218/274] [SPARK-8948][SQL] Remove ExtractValueWithOrdinal abstract class Also added more documentation for the file. Author: Reynold Xin Closes #7316 from rxin/extract-value and squashes the following commits: 069cb7e [Reynold Xin] Removed ExtractValueWithOrdinal. 621b705 [Reynold Xin] Reverted a line. 11ebd6c [Reynold Xin] [Minor][SQL] Improve documentation for complex type extractors. --- ...alue.scala => complexTypeExtractors.scala} | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{ExtractValue.scala => complexTypeExtractors.scala} (86%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala similarity index 86% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 2b25ba03579ec..73cc930c45832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines all the expressions to extract values out of complex types. +// For example, getting a field out of an array, map, or struct. +//////////////////////////////////////////////////////////////////////////////////////////////////// + object ExtractValue { /** @@ -73,11 +78,10 @@ object ExtractValue { } } - def unapply(g: ExtractValue): Option[(Expression, Expression)] = { - g match { - case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case s: ExtractValueWithStruct => Some((s.child, null)) - } + def unapply(g: ExtractValue): Option[(Expression, Expression)] = g match { + case o: GetArrayItem => Some((o.child, o.ordinal)) + case o: GetMapValue => Some((o.child, o.key)) + case s: ExtractValueWithStruct => Some((s.child, null)) } /** @@ -117,6 +121,8 @@ abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue /** * Returns the value of fields in the Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends ExtractValueWithStruct { @@ -142,6 +148,8 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) /** * Returns the array of value of fields in the Array of Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetArrayStructFields( child: Expression, @@ -178,25 +186,21 @@ case class GetArrayStructFields( } } -abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { - self: Product => +/** + * Returns the field at `ordinal` in the Array `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. + */ +case class GetArrayItem(child: Expression, ordinal: Expression) + extends BinaryExpression with ExtractValue { - def ordinal: Expression - def child: Expression + override def toString: String = s"$child[$ordinal]" override def left: Expression = child override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. */ override def nullable: Boolean = true - override def toString: String = s"$child[$ordinal]" -} - -/** - * Returns the field at `ordinal` in the Array `child` - */ -case class GetArrayItem(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType @@ -227,10 +231,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `ordinal` in Map `child` + * Returns the value of key `ordinal` in Map `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ -case class GetMapValue(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { +case class GetMapValue(child: Expression, key: Expression) + extends BinaryExpression with ExtractValue { + + override def toString: String = s"$child[$key]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType From a870a82fb6f57bb63bd6f1e95da944a30f67519a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 9 Jul 2015 10:01:33 -0700 Subject: [PATCH 219/274] [SPARK-8926][SQL] Code review followup. I merged https://github.com/apache/spark/pull/7303 so it unblocks another PR. This addresses my own code review comment for that PR. Author: Reynold Xin Closes #7313 from rxin/adt and squashes the following commits: 7ade82b [Reynold Xin] Fixed unit tests. f8d5533 [Reynold Xin] [SPARK-8926][SQL] Code review followup. --- .../catalyst/expressions/ExpectsInputTypes.scala | 4 ++-- .../spark/sql/types/AbstractDataType.scala | 16 ++++++++++++++++ .../catalyst/analysis/AnalysisErrorSuite.scala | 8 ++++---- .../analysis/HiveTypeCoercionSuite.scala | 1 + 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 986cc09499d1f..3eb0eb195c80d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -39,8 +39,8 @@ trait ExpectsInputTypes { self: Expression => override def checkInputDataTypes(): TypeCheckResult = { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => - s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, ${child.prettyString} is of type ${child.dataType.simpleString}." + s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index ad75fa2e31d90..32f87440b4e37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -36,12 +36,28 @@ private[sql] abstract class AbstractDataType { /** * Returns true if this data type is the same type as `other`. This is different that equality * as equality will also consider data type parametrization, such as decimal precision. + * + * {{{ + * // this should return true + * DecimalType.isSameType(DecimalType(10, 2)) + * + * // this should return false + * NumericType.isSameType(DecimalType(10, 2)) + * }}} */ private[sql] def isSameType(other: DataType): Boolean /** * Returns true if `other` is an acceptable input type for a function that expectes this, * possibly abstract, DataType. + * + * {{{ + * // this should return true + * DecimalType.isSameType(DecimalType(10, 2)) + * + * // this should return true as well + * NumericType.acceptsType(DecimalType(10, 2)) + * }}} */ private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 73236c3acbca2..9d0c69a2451d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -58,7 +58,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { } } - errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) + errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase))) } } @@ -68,21 +68,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" :: - "null is of type date" ::Nil) + "'null' is of type date" ::Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" :: - "null is of type date" ::Nil) + "'null' is of type date" ::Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "expected to be of type int" :: "null is of type date" ::Nil) + "expected to be of type int" :: "'null' is of type date" ::Nil) errorTest( "unresolved window function", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 6e3aa0eebeb15..acb9a433de903 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -79,6 +79,7 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) shouldCast(StringType, NumericType, DoubleType) + shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType) // NumericType should not be changed when function accepts any of them. Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, From f6c0bd5c3755b2f9bab633a5d478240fdaf1c593 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Jul 2015 10:04:42 -0700 Subject: [PATCH 220/274] [SPARK-8938][SQL] Implement toString for Interval data type Author: Wenchen Fan Closes #7315 from cloud-fan/toString and squashes the following commits: 4fc8d80 [Wenchen Fan] Implement toString for Interval data type --- .../apache/spark/sql/catalyst/SqlParser.scala | 24 ++++++-- .../apache/spark/unsafe/types/Interval.java | 42 +++++++++++++ .../spark/unsafe/types/IntervalSuite.java | 59 +++++++++++++++++++ 3 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index dedd8c8fa3620..d4ef04c2294a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -353,22 +353,34 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong } protected lazy val millisecond: Parser[Long] = - integral <~ intervalUnit("millisecond") ^^ { case num => num.toLong * 1000 } + integral <~ intervalUnit("millisecond") ^^ { + case num => num.toLong * Interval.MICROS_PER_MILLI + } protected lazy val second: Parser[Long] = - integral <~ intervalUnit("second") ^^ { case num => num.toLong * 1000 * 1000 } + integral <~ intervalUnit("second") ^^ { + case num => num.toLong * Interval.MICROS_PER_SECOND + } protected lazy val minute: Parser[Long] = - integral <~ intervalUnit("minute") ^^ { case num => num.toLong * 1000 * 1000 * 60 } + integral <~ intervalUnit("minute") ^^ { + case num => num.toLong * Interval.MICROS_PER_MINUTE + } protected lazy val hour: Parser[Long] = - integral <~ intervalUnit("hour") ^^ { case num => num.toLong * 1000 * 1000 * 3600 } + integral <~ intervalUnit("hour") ^^ { + case num => num.toLong * Interval.MICROS_PER_HOUR + } protected lazy val day: Parser[Long] = - integral <~ intervalUnit("day") ^^ { case num => num.toLong * 1000 * 1000 * 3600 * 24 } + integral <~ intervalUnit("day") ^^ { + case num => num.toLong * Interval.MICROS_PER_DAY + } protected lazy val week: Parser[Long] = - integral <~ intervalUnit("week") ^^ { case num => num.toLong * 1000 * 1000 * 3600 * 24 * 7 } + integral <~ intervalUnit("week") ^^ { + case num => num.toLong * Interval.MICROS_PER_WEEK + } protected lazy val intervalLiteral: Parser[Literal] = INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index 3eb67ede062d9..0af982d4844c2 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -23,6 +23,13 @@ * The internal representation of interval type. */ public final class Interval implements Serializable { + public static final long MICROS_PER_MILLI = 1000L; + public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; + public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; + public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60; + public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; + public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; + public final int months; public final long microseconds; @@ -44,4 +51,39 @@ public boolean equals(Object other) { public int hashCode() { return 31 * months + (int) microseconds; } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("interval"); + + if (months != 0) { + appendUnit(sb, months / 12, "year"); + appendUnit(sb, months % 12, "month"); + } + + if (microseconds != 0) { + long rest = microseconds; + appendUnit(sb, rest / MICROS_PER_WEEK, "week"); + rest %= MICROS_PER_WEEK; + appendUnit(sb, rest / MICROS_PER_DAY, "day"); + rest %= MICROS_PER_DAY; + appendUnit(sb, rest / MICROS_PER_HOUR, "hour"); + rest %= MICROS_PER_HOUR; + appendUnit(sb, rest / MICROS_PER_MINUTE, "minute"); + rest %= MICROS_PER_MINUTE; + appendUnit(sb, rest / MICROS_PER_SECOND, "second"); + rest %= MICROS_PER_SECOND; + appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); + rest %= MICROS_PER_MILLI; + appendUnit(sb, rest, "microsecond"); + } + + return sb.toString(); + } + + private void appendUnit(StringBuilder sb, long value, String unit) { + if (value != 0) { + sb.append(" " + value + " " + unit + "s"); + } + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java new file mode 100644 index 0000000000000..0f4f38b2b03be --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -0,0 +1,59 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.unsafe.types; + +import org.junit.Test; + +import static junit.framework.Assert.*; +import static org.apache.spark.unsafe.types.Interval.*; + +public class IntervalSuite { + + @Test + public void equalsTest() { + Interval i1 = new Interval(3, 123); + Interval i2 = new Interval(3, 321); + Interval i3 = new Interval(1, 123); + Interval i4 = new Interval(3, 123); + + assertNotSame(i1, i2); + assertNotSame(i1, i3); + assertNotSame(i2, i3); + assertEquals(i1, i4); + } + + @Test + public void toStringTest() { + Interval i; + + i = new Interval(34, 0); + assertEquals(i.toString(), "interval 2 years 10 months"); + + i = new Interval(-34, 0); + assertEquals(i.toString(), "interval -2 years -10 months"); + + i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); + + i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); + + i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); + } +} From c59e268d17cf10e46dbdbe760e2a7580a6364692 Mon Sep 17 00:00:00 2001 From: JPark Date: Thu, 9 Jul 2015 10:23:36 -0700 Subject: [PATCH 221/274] [SPARK-8863] [EC2] Check aws access key from aws credentials if there is no boto config 'spark_ec2.py' use boto to control ec2. And boto can support '~/.aws/credentials' which is AWS CLI default configuration file. We can check this information from ref of boto. "A boto config file is a text file formatted like an .ini configuration file that specifies values for options that control the behavior of the boto library. In Unix/Linux systems, on startup, the boto library looks for configuration files in the following locations and in the following order: /etc/boto.cfg - for site-wide settings that all users on this machine will use (if profile is given) ~/.aws/credentials - for credentials shared between SDKs (if profile is given) ~/.boto - for user-specific settings ~/.aws/credentials - for credentials shared between SDKs ~/.boto - for user-specific settings" * ref of boto: http://boto.readthedocs.org/en/latest/boto_config_tut.html * ref of aws cli : http://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html However 'spark_ec2.py' only check boto config & environment variable even if there is '~/.aws/credentials', and 'spark_ec2.py' is terminated. So I changed to check '~/.aws/credentials'. cc rxin Jira : https://issues.apache.org/jira/browse/SPARK-8863 Author: JPark Closes #7252 from JuhongPark/master and squashes the following commits: 23c5792 [JPark] Check aws access key from aws credentials if there is no boto config --- ec2/spark_ec2.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index dd0c12d25980b..ae4f2ecc5bde7 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -325,14 +325,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) From 0cd84c86cac68600a74d84e50ad40c0c8b84822a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 9 Jul 2015 10:26:38 -0700 Subject: [PATCH 222/274] [SPARK-8703] [ML] Add CountVectorizer as a ml transformer to convert document to words count vector jira: https://issues.apache.org/jira/browse/SPARK-8703 Converts a text document to a sparse vector of token counts. I can further add an estimator to extract vocabulary from corpus if that's appropriate. Author: Yuhao Yang Closes #7084 from hhbyyh/countVectorization and squashes the following commits: 5f3f655 [Yuhao Yang] text change 24728e4 [Yuhao Yang] style improvement 576728a [Yuhao Yang] rename to model and some fix 1deca28 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization 99b0c14 [Yuhao Yang] undo extension from HashingTF 12c2dc8 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization 7ee1c31 [Yuhao Yang] extends HashingTF 809fb59 [Yuhao Yang] minor fix for ut 7c61fb3 [Yuhao Yang] add countVectorizer --- .../ml/feature/CountVectorizerModel.scala | 82 +++++++++++++++++++ .../ml/feature/CountVectorizorSuite.scala | 73 +++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala new file mode 100644 index 0000000000000..6b77de89a0330 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. + */ +@Experimental +class CountVectorizerModel (override val uid: String, val vocabulary: Array[String]) + extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { + + def this(vocabulary: Array[String]) = + this(Identifiable.randomUID("cntVec"), vocabulary) + + /** + * Corpus-specific filter to ignore scarce words in a document. For each document, terms with + * frequency (count) less than the given threshold are ignored. + * Default: 1 + * @group param + */ + val minTermFreq: IntParam = new IntParam(this, "minTermFreq", + "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + + "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) + + /** @group setParam */ + def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + + /** @group getParam */ + def getMinTermFreq: Int = $(minTermFreq) + + setDefault(minTermFreq -> 1) + + override protected def createTransformFunc: Seq[String] => Vector = { + val dict = vocabulary.zipWithIndex.toMap + document => + val termCounts = mutable.HashMap.empty[Int, Double] + document.foreach { term => + dict.get(term) match { + case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) + case None => // ignore terms not in the vocabulary + } + } + Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): CountVectorizerModel = { + val copied = new CountVectorizerModel(uid, vocabulary) + copyValues(copied, extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala new file mode 100644 index 0000000000000..e90d9d4ef21ff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) + } + + test("CountVectorizerModel common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, "a b c d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, "a b b c d a".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), + (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string + (4, "a notInDict d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary + )).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + val output = cv.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizerModel with minTermFreq") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), + (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTermFreq(3) + val output = cv.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } +} + + From 0b0b9ceaf73de472198c9804fb7ae61fa2a2e097 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 9 Jul 2015 11:11:34 -0700 Subject: [PATCH 223/274] [SPARK-8247] [SPARK-8249] [SPARK-8252] [SPARK-8254] [SPARK-8257] [SPARK-8258] [SPARK-8259] [SPARK-8261] [SPARK-8262] [SPARK-8253] [SPARK-8260] [SPARK-8267] [SQL] Add String Expressions Author: Cheng Hao Closes #6762 from chenghao-intel/str_funcs and squashes the following commits: b09a909 [Cheng Hao] update the code as feedback 7ebbf4c [Cheng Hao] Add more string expressions --- .../catalyst/analysis/FunctionRegistry.scala | 12 + .../expressions/stringOperations.scala | 306 ++++++++++++++- .../expressions/StringFunctionsSuite.scala | 138 +++++++ .../org/apache/spark/sql/functions.scala | 353 ++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 132 ++++++- .../apache/spark/unsafe/types/UTF8String.java | 191 ++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 94 ++++- 7 files changed, 1202 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5c25181e1cf50..f62d79f8cea6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -147,12 +147,24 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), + expression[StringInstr]("instr"), expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), expression[Levenshtein]("levenshtein"), + expression[StringLocate]("locate"), + expression[StringLPad]("lpad"), + expression[StringTrimLeft]("ltrim"), + expression[StringFormat]("printf"), + expression[StringRPad]("rpad"), + expression[StringRepeat]("repeat"), + expression[StringReverse]("reverse"), + expression[StringTrimRight]("rtrim"), + expression[StringSpace]("space"), + expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), expression[Unhex]("unhex"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 57f436485becf..f64899c1ed84c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.Pattern import org.apache.commons.lang3.StringUtils @@ -104,7 +105,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait CaseConversionExpression extends ExpectsInputTypes { +trait String2StringExpression extends ExpectsInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -119,7 +120,7 @@ trait CaseConversionExpression extends ExpectsInputTypes { /** * A function that converts the characters of a string to uppercase. */ -case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -131,7 +132,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE /** * A function that converts the characters of a string to lowercase. */ -case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -187,6 +188,301 @@ case class EndsWith(left: Expression, right: Expression) } } +/** + * A function that trim the spaces from both ends for the specified string. + */ +case class StringTrim(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trim() + + override def prettyName: String = "trim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trim()") + } +} + +/** + * A function that trim the spaces from left end for given string. + */ +case class StringTrimLeft(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimLeft() + + override def prettyName: String = "ltrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + } +} + +/** + * A function that trim the spaces from right end for given string. + */ +case class StringTrimRight(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimRight() + + override def prettyName: String = "rtrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimRight()") + } +} + +/** + * A function that returns the position of the first occurrence of substr in the given string. + * Returns null if either of the arguments are null and + * returns 0 if substr could not be found in str. + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. + */ +case class StringInstr(str: Expression, substr: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = substr + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, sub: Any): Any = { + string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 + } + + override def prettyName: String = "instr" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => + s"($l).indexOf($r, 0) + 1") + } +} + +/** + * A function that returns the position of the first occurrence of substr + * in given string after position pos. + */ +case class StringLocate(substr: Expression, str: Expression, start: Expression) + extends Expression with ExpectsInputTypes { + + def this(substr: Expression, str: Expression) = { + this(substr, str, Literal(0)) + } + + override def children: Seq[Expression] = substr :: str :: start :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = substr.nullable || str.nullable + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + + override def eval(input: InternalRow): Any = { + val s = start.eval(input) + if (s == null) { + // if the start position is null, we need to return 0, (conform to Hive) + 0 + } else { + val r = substr.eval(input) + if (r == null) { + null + } else { + val l = str.eval(input) + if (l == null) { + null + } else { + l.asInstanceOf[UTF8String].indexOf( + r.asInstanceOf[UTF8String], + s.asInstanceOf[Int]) + 1 + } + } + } + } + + override def prettyName: String = "locate" +} + +/** + * Returns str, left-padded with pad to a length of len. + */ +case class StringLPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.lpad(len, pad) + } + } + } + } + + override def prettyName: String = "lpad" +} + +/** + * Returns str, right-padded with pad to a length of len. + */ +case class StringRPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.rpad(len, pad) + } + } + } + } + + override def prettyName: String = "rpad" +} + +/** + * Returns the input formatted according do printf-style format strings + */ +case class StringFormat(children: Expression*) extends Expression { + + require(children.length >=1, "printf() should take at least 1 argument") + + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children(0).nullable + override def dataType: DataType = StringType + private def format: Expression = children(0) + private def args: Seq[Expression] = children.tail + + override def eval(input: InternalRow): Any = { + val pattern = format.eval(input) + if (pattern == null) { + null + } else { + val sb = new StringBuffer() + val formatter = new java.util.Formatter(sb, Locale.US) + + val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) + formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*) + + UTF8String.fromString(sb.toString) + } + } + + override def prettyName: String = "printf" +} + +/** + * Returns the string which repeat the given string value n times. + */ +case class StringRepeat(str: Expression, times: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = times + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType) + + override def nullSafeEval(string: Any, n: Any): Any = { + string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) + } + + override def prettyName: String = "repeat" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") + } +} + +/** + * Returns the reversed given string. + */ +case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { + override def convert(v: UTF8String): UTF8String = v.reverse() + + override def prettyName: String = "reverse" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).reverse()") + } +} + +/** + * Returns a n spaces string. + */ +case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def nullSafeEval(s: Any): Any = { + val length = s.asInstanceOf[Integer] + + val spaces = new Array[Byte](if (length < 0) 0 else length) + java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte]) + UTF8String.fromBytes(spaces) + } + + override def prettyName: String = "space" +} + +/** + * Splits str around pat (pattern is a regular expression). + */ +case class StringSplit(str: Expression, pattern: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = pattern + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, regex: Any): Any = { + val splits = + string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1) + splits.toSeq.map(UTF8String.fromString) + } + + override def prettyName: String = "split" +} + /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. @@ -199,8 +495,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable override def dataType: DataType = { if (!resolved) { @@ -373,4 +668,3 @@ case class Encode(value: Expression, charset: Expression) } } - diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 69bef1c63e9dc..b19f4ee37a109 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -288,4 +288,142 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) // scalastyle:on } + + test("TRIM/LTRIM/RTRIM") { + val s = 'a.string.at(0) + checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) + + checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) + + checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + // scalastyle:on + } + + test("FORMAT") { + val f = 'f.string.at(0) + val d1 = 'd.int.at(1) + val s1 = 's.int.at(2) + + val row1 = create_row("aa%d%s", 12, "cc") + val row2 = create_row(null, 12, "cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + + checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) + checkEvaluation(StringFormat(f, d1, s1), null, row2) + } + + test("INSTR") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("aaads", "aa", "zz") + + checkEvaluation(StringInstr(Literal("aaads"), Literal("aa")), 1, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal("de")), 0, row1) + checkEvaluation(StringInstr(Literal.create(null, StringType), Literal("de")), null, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal.create(null, StringType)), null, row1) + + checkEvaluation(StringInstr(s1, s2), 1, row1) + checkEvaluation(StringInstr(s1, s3), 0, row1) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringInstr(s1, s2), 3, create_row("花花世界", "世界")) + checkEvaluation(StringInstr(s1, s2), 1, create_row("花花世界", "花")) + checkEvaluation(StringInstr(s1, s2), 0, create_row("花花世界", "小")) + // scalastyle:on + } + + test("LOCATE") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val s4 = 'd.int.at(3) + val row1 = create_row("aaads", "aa", "zz", 1) + + checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1) + checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1) + checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1) + + checkEvaluation(new StringLocate(s2, s1), 1, row1) + checkEvaluation(StringLocate(s2, s1, s4), 2, row1) + checkEvaluation(new StringLocate(s3, s1), 0, row1) + checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1) + } + + test("LPAD/RPAD") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("hi", 5, "??") + val row2 = create_row("hi", 1, "?") + val row3 = create_row(null, 1, "?") + + checkEvaluation(StringLPad(Literal("hi"), Literal(5), Literal("??")), "???hi", row1) + checkEvaluation(StringLPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringLPad(s1, s2, s3), "???hi", row1) + checkEvaluation(StringLPad(s1, s2, s3), "h", row2) + checkEvaluation(StringLPad(s1, s2, s3), null, row3) + + checkEvaluation(StringRPad(Literal("hi"), Literal(5), Literal("??")), "hi???", row1) + checkEvaluation(StringRPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringRPad(s1, s2, s3), "hi???", row1) + checkEvaluation(StringRPad(s1, s2, s3), "h", row2) + checkEvaluation(StringRPad(s1, s2, s3), null, row3) + } + + test("REPEAT") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val row1 = create_row("hi", 2) + val row2 = create_row(null, 1) + + checkEvaluation(StringRepeat(Literal("hi"), Literal(2)), "hihi", row1) + checkEvaluation(StringRepeat(Literal("hi"), Literal(-1)), "", row1) + checkEvaluation(StringRepeat(s1, s2), "hihi", row1) + checkEvaluation(StringRepeat(s1, s2), null, row2) + } + + test("REVERSE") { + val s = 'a.string.at(0) + val row1 = create_row("abccc") + checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) + checkEvaluation(StringReverse(s), "cccba", row1) + } + + test("SPACE") { + val s1 = 'b.int.at(0) + val row1 = create_row(2) + val row2 = create_row(null) + + checkEvaluation(StringSpace(Literal(2)), " ", row1) + checkEvaluation(StringSpace(Literal(-1)), "", row1) + checkEvaluation(StringSpace(Literal(0)), "", row1) + checkEvaluation(StringSpace(s1), " ", row1) + checkEvaluation(StringSpace(s1), null, row2) + } + + test("SPLIT") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val row1 = create_row("aa2bb3cc", "[1-9]+") + + checkEvaluation( + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + checkEvaluation( + StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4da9ffc495e17..08bf37a5c223c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1626,6 +1626,179 @@ object functions { */ def ascii(columnName: String): Column = ascii(Column(columnName)) + /** + * Trim the spaces from both ends for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(e: Column): Column = StringTrim(e.expr) + + /** + * Trim the spaces from both ends for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(columnName: String): Column = trim(Column(columnName)) + + /** + * Trim the spaces from left end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(e: Column): Column = StringTrimLeft(e.expr) + + /** + * Trim the spaces from left end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(columnName: String): Column = ltrim(Column(columnName)) + + /** + * Trim the spaces from right end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(e: Column): Column = StringTrimRight(e.expr) + + /** + * Trim the spaces from right end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(columnName: String): Column = rtrim(Column(columnName)) + + /** + * Format strings in printf-style. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: Column, arguments: Column*): Column = { + StringFormat((format +: arguments).map(_.expr): _*) + } + + /** + * Format strings in printf-style. + * NOTE: `format` is the string value of the formatter, not column name. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: String, arguNames: String*): Column = { + StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) + } + + /** + * Locate the position of the first occurrence of substr value in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub)) + + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr) + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String): Column = { + locate(Column(substr), Column(str)) + } + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column): Column = { + new StringLocate(substr.expr, str.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: String): Column = { + locate(Column(substr), Column(str), Column(pos)) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Column): Column = { + StringLocate(substr.expr, str.expr, pos.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Int): Column = { + StringLocate(substr.expr, str.expr, lit(pos).expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: Int): Column = { + locate(Column(substr), Column(str), lit(pos)) + } + /** * Computes the specified value from binary to a base64 string. * @@ -1658,6 +1831,46 @@ object functions { */ def unbase64(columnName: String): Column = unbase64(Column(columnName)) + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: String, pad: String): Column = { + lpad(Column(str), Column(len), Column(pad)) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Column, pad: Column): Column = { + StringLPad(str.expr, len.expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Int, pad: Column): Column = { + StringLPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: Int, pad: String): Column = { + lpad(Column(str), len, Column(pad)) + } + /** * Computes the first argument into a binary from a string using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). @@ -1702,6 +1915,146 @@ object functions { def decode(columnName: String, charset: String): Column = decode(Column(columnName), charset) + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: String, pad: String): Column = { + rpad(Column(str), Column(len), Column(pad)) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Column, pad: Column): Column = { + StringRPad(str.expr, len.expr, pad.expr) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: Int, pad: String): Column = { + rpad(Column(str), len, Column(pad)) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Int, pad: Column): Column = { + StringRPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, timesColumn: String): Column = { + repeat(Column(strColumn), Column(timesColumn)) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Column): Column = { + StringRepeat(str.expr, times.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, times: Int): Column = { + repeat(Column(strColumn), times) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Int): Column = { + StringRepeat(str.expr, lit(times).expr) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * + * @group string_funcs + * @since 1.5.0 + */ + def split(strColumnName: String, pattern: String): Column = { + split(Column(strColumnName), pattern) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * NOTE: pattern is a string represent the regular expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def split(str: Column, pattern: String): Column = { + StringSplit(str.expr, lit(pattern).expr) + } + + /** + * Reversed the string for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: String): Column = { + reverse(Column(str)) + } + + /** + * Reversed the string for the specified value. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: Column): Column = { + StringReverse(str.expr) + } + + /** + * Make a n spaces of string. + * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: String): Column = { + space(Column(n)) + } + + /** + * Make a n spaces of string. + * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: Column): Column = { + StringSpace(n.expr) + } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index afba28515e032..173280375c411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -209,21 +209,14 @@ class DataFrameFunctionsSuite extends QueryTest { } test("string length function") { + val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( - nullStrings.select(strlen($"s"), strlen("s")), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l, l) - }) + df.select(strlen($"a"), strlen("b")), + Row(3, 0)) checkAnswer( - nullStrings.selectExpr("length(s)"), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l) - }) + df.selectExpr("length(a)", "length(b)"), + Row(3, 0)) } test("Levenshtein distance") { @@ -273,4 +266,119 @@ class DataFrameFunctionsSuite extends QueryTest { Row(bytes, "大千世界")) // scalastyle:on } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", $"b"), instr("a", "b")), + Row(1, 1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select( + locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), + locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), + Row(1, 1, 2, 2, 2, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select( + lpad($"a", $"b", $"c"), rpad("a", "b", "c"), + lpad($"a", 1, $"c"), rpad("a", 1, "c")), + Row("???hi", "hi???", "h", "h")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select( + repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), + Row("hihi", "hihi", "hihi", "hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse("b")), + Row("ih", "ihhh")) + + checkAnswer( + df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + checkAnswer( + df.select(space($"a"), space("b")), + Row(" ", " ")) + + checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select( + split($"a", "[1-9]+"), + split("a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 847d80ad583f6..60d050b0a0c97 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -25,6 +25,7 @@ import static org.apache.spark.unsafe.PlatformDependent.*; + /** * A UTF-8 String for internal Spark use. *

@@ -204,6 +205,196 @@ public UTF8String toLowerCase() { return fromString(toString().toLowerCase()); } + /** + * Copy the bytes from the current UTF8String, and make a new UTF8String. + * @param start the start position of the current UTF8String in bytes. + * @param end the end position of the current UTF8String in bytes. + * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. + */ + private UTF8String copyUTF8String(int start, int end) { + int len = end - start + 1; + byte[] newBytes = new byte[len]; + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + return UTF8String.fromBytes(newBytes); + } + + public UTF8String trim() { + int s = 0; + int e = this.numBytes - 1; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (s > e) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, e); + } + } + + public UTF8String trimLeft() { + int s = 0; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + if (s == this.numBytes) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + + public UTF8String trimRight() { + int e = numBytes - 1; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (e < 0) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(0, e); + } + } + + public UTF8String reverse() { + byte[] bytes = getBytes(); + byte[] result = new byte[bytes.length]; + + int i = 0; // position in byte + while (i < numBytes) { + int len = numBytesForFirstByte(getByte(i)); + System.arraycopy(bytes, i, result, result.length - i - len, len); + + i += len; + } + + return UTF8String.fromBytes(result); + } + + public UTF8String repeat(int times) { + if (times <=0) { + return fromBytes(new byte[0]); + } + + byte[] newBytes = new byte[numBytes * times]; + System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + + int copied = 1; + while (copied < times) { + int toCopy = Math.min(copied, times - copied); + System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy); + copied += toCopy; + } + + return UTF8String.fromBytes(newBytes); + } + + /** + * Returns the position of the first occurrence of substr in + * current string from the specified position (0-based index). + * + * @param v the string to be searched + * @param start the start position of the current string for searching + * @return the position of the first occurrence of substr, if not found, -1 returned. + */ + public int indexOf(UTF8String v, int start) { + if (v.numBytes() == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + v.numBytes > numBytes) { + return -1; + } + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while(i < numBytes); + + return -1; + } + + /** + * Returns str, right-padded with pad to a length of len + * For example: + * ('hi', 5, '??') => 'hi???' + * ('hi', 1, '??') => 'h' + */ + public UTF8String rpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + int offset = this.numBytes; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + + return UTF8String.fromBytes(data); + } + } + + /** + * Returns str, left-padded with pad to a length of len. + * For example: + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' + */ + public UTF8String lpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + + int offset = 0; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + offset += remain.numBytes; + System.arraycopy(getBytes(), 0, data, offset, numBytes()); + + return UTF8String.fromBytes(data); + } + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index fb463ba17f50b..694bdc29f39d1 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -121,12 +121,94 @@ public void endsWith() { @Test public void substring() { - assertEquals(fromString("hello").substring(0, 0), fromString("")); - assertEquals(fromString("hello").substring(1, 3), fromString("el")); - assertEquals(fromString("数据砖头").substring(0, 1), fromString("数")); - assertEquals(fromString("数据砖头").substring(1, 3), fromString("据砖")); - assertEquals(fromString("数据砖头").substring(3, 5), fromString("头")); - assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷")); + assertEquals(fromString(""), fromString("hello").substring(0, 0)); + assertEquals(fromString("el"), fromString("hello").substring(1, 3)); + assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); + assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + } + + @Test + public void trims() { + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + + assertEquals(fromString(""), fromString(" ").trim()); + assertEquals(fromString(""), fromString(" ").trimLeft()); + assertEquals(fromString(""), fromString(" ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + } + + @Test + public void indexOf() { + assertEquals(0, fromString("").indexOf(fromString(""), 0)); + assertEquals(-1, fromString("").indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(fromString(""), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0)); + assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4)); + assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0)); + assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); + } + + @Test + public void reverse() { + assertEquals(fromString("olleh"), fromString("hello").reverse()); + assertEquals(fromString(""), fromString("").reverse()); + assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); + assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + } + + @Test + public void repeat() { + assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); + assertEquals(fromString("数d"), fromString("数d").repeat(1)); + assertEquals(fromString(""), fromString("数d").repeat(-1)); + } + + @Test + public void pad() { + assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????"))); + assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); + assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); + assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????"))); + + assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); + assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); + assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); + assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????"))); + + + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); + assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????"))); + assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); + assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); + assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者"))); + assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); + assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); + assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); + assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); } @Test From 7ce3b818fb1ba3f291eda58988e4808e999cae3a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 9 Jul 2015 13:19:36 -0700 Subject: [PATCH 224/274] [MINOR] [STREAMING] Fix log statements in ReceiverSupervisorImpl Log statements incorrectly showed that the executor was being stopped when receiver was being stopped. Author: Tathagata Das Closes #7328 from tdas/fix-log and squashes the following commits: 9cc6e99 [Tathagata Das] Fix log statements. --- .../spark/streaming/receiver/ReceiverSupervisor.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 33be067ebdaf2..eeb14ca3a49e9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -182,12 +182,12 @@ private[streaming] abstract class ReceiverSupervisor( /** Wait the thread until the supervisor is stopped */ def awaitTermination() { + logInfo("Waiting for receiver to be stopped") stopLatch.await() - logInfo("Waiting for executor stop is over") if (stoppingError != null) { - logError("Stopped executor with error: " + stoppingError) + logError("Stopped receiver with error: " + stoppingError) } else { - logWarning("Stopped executor without error") + logInfo("Stopped receiver without error") } if (stoppingError != null) { throw stoppingError From 930fe95350f8865e2af2d7afa5b717210933cd43 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Thu, 9 Jul 2015 13:21:10 -0700 Subject: [PATCH 225/274] [SPARK-8953] SPARK_EXECUTOR_CORES is not read in SparkSubmit The configuration ```SPARK_EXECUTOR_CORES``` won't put into ```SparkConf```, so it has no effect to the dynamic executor allocation. Author: xutingjun Closes #7322 from XuTingjun/SPARK_EXECUTOR_CORES and squashes the following commits: 2cafa89 [xutingjun] make SPARK_EXECUTOR_CORES has effect to dynamicAllocation --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 73ab18332feb4..6e3c0b21b33c2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -162,6 +162,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) From 88bf430331eef3c02438ca441616034486e15789 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Jul 2015 13:22:17 -0700 Subject: [PATCH 226/274] [SPARK-7419] [STREAMING] [TESTS] Fix CheckpointSuite.recovery with file input stream Fix this failure: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/2886/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.3,label=centos/testReport/junit/org.apache.spark.streaming/CheckpointSuite/recovery_with_file_input_stream/ To reproduce this failure, you can add `Thread.sleep(2000)` before this line https://github.com/apache/spark/blob/a9c4e29950a14e32acaac547e9a0e8879fd37fc9/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala#L477 Author: zsxwing Closes #7323 from zsxwing/SPARK-7419 and squashes the following commits: b3caf58 [zsxwing] Fix CheckpointSuite.recovery with file input stream --- .../spark/streaming/CheckpointSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d06..6a94928076236 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -424,11 +424,11 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } + clock.advance(batchDuration.milliseconds) // Wait for a checkpoint to be written eventually(eventuallyTimeout) { assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) @@ -454,9 +454,12 @@ class CheckpointSuite extends TestSuiteBase { // recorded before failure were saved and successfully recovered logInfo("*********** RESTARTING ************") withStreamingContext(new StreamingContext(checkpointDir)) { ssc => - // So that the restarted StreamingContext's clock has gone forward in time since failure - ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.getTimeMillis() + // "batchDuration.milliseconds * 3" has gone before restarting StreamingContext. And because + // the recovery time is read from the checkpoint time but the original clock doesn't align + // with the batch time, we need to add the offset "batchDuration.milliseconds / 2". + ssc.conf.set("spark.streaming.manualClock.jump", + (batchDuration.milliseconds / 2 + batchDuration.milliseconds * 3).toString) + val oldClockTime = clock.getTimeMillis() // 15000ms clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -467,10 +470,10 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.getTimeMillis() === oldClockTime + assert(clock.getTimeMillis() === oldClockTime) } - // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) - val numBatchesAfterRestart = 4 + // There are 5 batches between 6000ms and 15000ms (inclusive). + val numBatchesAfterRestart = 5 eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === numBatchesAfterRestart) } @@ -483,7 +486,6 @@ class CheckpointSuite extends TestSuiteBase { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.advance(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() From ebdf58538058e57381c04b6725d4be0c37847ed3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 9 Jul 2015 13:25:11 -0700 Subject: [PATCH 227/274] [SPARK-2017] [UI] Stage page hangs with many tasks (This reopens a patch that was closed in the past: #6248) When you view the stage page while running the following: ``` sc.parallelize(1 to X, 10000).count() ``` The page never loads, the job is stalled, and you end up running into an OOM: ``` HTTP ERROR 500 Problem accessing /stages/stage/. Reason: Server Error Caused by: java.lang.OutOfMemoryError: Java heap space at java.util.Arrays.copyOf(Arrays.java:2367) at java.lang.AbstractStringBuilder.expandCapacity(AbstractStringBuilder.java:130) ``` This patch compresses Jetty responses in gzip. The correct long-term fix is to add pagination. Author: Andrew Or Closes #7296 from andrewor14/gzip-jetty and squashes the following commits: a051c64 [Andrew Or] Use GZIP to compress Jetty responses --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 06e616220c706..f413c1d37fbb6 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -210,10 +210,16 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) From c4830598b271cc6390d127bd4cf8ab02b28792e0 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Thu, 9 Jul 2015 13:26:46 -0700 Subject: [PATCH 228/274] [SPARK-6287] [MESOS] Add dynamic allocation to the coarse-grained Mesos scheduler This is largely based on extracting the dynamic allocation parts from tnachen's #3861. Author: Iulian Dragos Closes #4984 from dragos/issue/mesos-coarse-dynamicAllocation and squashes the following commits: 39df8cd [Iulian Dragos] Update tests to latest changes in core. 9d2c9fa [Iulian Dragos] Remove adjustment of executorLimitOption in doKillExecutors. 8b00f52 [Iulian Dragos] Latest round of reviews. 0cd00e0 [Iulian Dragos] Add persistent shuffle directory 15c45c1 [Iulian Dragos] Add dynamic allocation to the Spark coarse-grained scheduler. --- .../scala/org/apache/spark/SparkContext.scala | 19 +- .../mesos/CoarseMesosSchedulerBackend.scala | 136 +++++++++++--- .../cluster/mesos/MesosSchedulerUtils.scala | 4 +- .../spark/storage/DiskBlockManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 45 +++-- .../CoarseMesosSchedulerBackendSuite.scala | 175 ++++++++++++++++++ 6 files changed, 331 insertions(+), 56 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d2547eeff2b4e..82704b1ab2189 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -532,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _executorAllocationManager = if (dynamicAllocationEnabled) { assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") + "Dynamic allocation of executors is currently only supported in YARN and Mesos mode") Some(new ExecutorAllocationManager(this, listenerBus, _conf)) } else { None @@ -853,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -1364,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Return whether dynamically adjusting the amount of resources allocated to - * this application is supported. This is currently only available for YARN. + * this application is supported. This is currently only available for YARN + * and Mesos coarse-grained mode. */ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) + private[spark] def supportDynamicAllocation: Boolean = { + (master.contains("yarn") + || master.contains("mesos") + || _conf.getBoolean("spark.dynamicAllocation.testing", false)) + } /** * :: DeveloperApi :: @@ -1385,7 +1388,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestTotalExecutors(numExecutors) @@ -1403,7 +1406,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1421,7 +1424,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { assert(supportDynamicAllocation, - "Killing executors is currently only supported in YARN mode") + "Killing executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b68f8c7685eba..cbade131494bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,11 +18,14 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} +import java.util.{List => JList, Collections} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.collect.HashBiMap +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} @@ -60,9 +63,27 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocatorManager calls [[doRequesTotalExecutors]]. + */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + + private val pendingRemovedSlaveIds = new HashSet[String] + // private lock object protecting mutable state above. Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) @@ -86,7 +107,7 @@ private[spark] class CoarseMesosSchedulerBackend( startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -120,10 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) @@ -133,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend( command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + - s" --driver-url $driverUrl" + + s" --driver-url $driverURL" + s" --executor-id ${offer.getSlaveId.getValue}" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + @@ -142,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + - s" --driver-url $driverUrl" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -155,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend( command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { @@ -172,17 +201,18 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { + stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) - val slaveId = offer.getSlaveId.toString + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (meetsConstraints && + if (taskIdToSlaveId.size < executorLimit && totalCoresAcquired < maxCores && + meetsConstraints && mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && @@ -197,7 +227,7 @@ private[spark] class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", calculateTotalMemory(sc))) @@ -209,7 +239,9 @@ private[spark] class CoarseMesosSchedulerBackend( // accept the offer and launch the task logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.launchTasks(List(offer.getId), List(task.build()), filters) + d.launchTasks( + Collections.singleton(offer.getId), + Collections.singleton(task.build()), filters) } else { // Decline the offer logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") @@ -224,7 +256,7 @@ private[spark] class CoarseMesosSchedulerBackend( val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo("Mesos task " + taskId + " is now " + state) - synchronized { + stateLock.synchronized { if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId @@ -242,8 +274,9 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } + executorTerminated(d, slaveId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node - mesosDriver.reviveOffers() + d.reviveOffers() } } } @@ -262,18 +295,39 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. + */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.contains(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo("Mesos slave lost: " + slaveId.getValue) + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -284,4 +338,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. + logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.contains(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. + // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d8a8c848bb4d1..925702e63afd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import scala.util.control.NonFatal import com.google.common.base.Splitter -import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler} +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} import org.apache.mesos.Protos._ import org.apache.mesos.protobuf.GeneratedMessage import org.apache.spark.{Logging, SparkContext} @@ -39,7 +39,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private final val registerLatch = new CountDownLatch(1) // Driver for talking to Mesos - protected var mesosDriver: MesosSchedulerDriver = null + protected var mesosDriver: SchedulerDriver = null /** * Starts the MesosSchedulerDriver with the provided information. This method returns diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 91ef86389a0c3..5f537692a16c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. + */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") + Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 944560a91354a..b6b932104a94d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -733,7 +733,12 @@ private[spark] object Utils extends Logging { localRootDirs } - private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + /** + * Return the configured local directories where Spark can write files. This + * method does not create any directories on its own, it only encapsulates the + * logic of locating the local directories according to deployment mode. + */ + def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has @@ -749,27 +754,29 @@ private[spark] object Utils extends Logging { Option(conf.getenv("SPARK_LOCAL_DIRS")) .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) .split(",") - .flatMap { root => - try { - val rootDir = new File(root) - if (rootDir.exists || rootDir.mkdirs()) { - val dir = createTempDir(root) - chmod700(dir) - Some(dir.getAbsolutePath) - } else { - logError(s"Failed to create dir in $root. Ignoring this directory.") - None - } - } catch { - case e: IOException => - logError(s"Failed to create local root dir in $root. Ignoring this directory.") - None - } - } - .toArray } } + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + getConfiguredLocalDirs(conf).flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + }.toArray + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..3f1692917a357 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util +import java.util.Collections + +import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.Protos._ +import org.apache.mesos.SchedulerDriver +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.scalatest.mock.MockitoSugar +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} + +class CoarseMesosSchedulerBackendSuite extends SparkFunSuite + with LocalSparkContext + with MockitoSugar + with BeforeAndAfter { + + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder() + .setValue(offerId).build()) + .setFrameworkId(FrameworkID.newBuilder() + .setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setHostname(s"host${slaveId}") + .build() + } + + private def createSchedulerBackend( + taskScheduler: TaskSchedulerImpl, + driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { + mesosDriver = driver + markRegistered() + } + backend.start() + backend + } + + var sparkConf: SparkConf = _ + + before { + sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + + sc = new SparkContext(sparkConf) + } + + test("mesos supports killing and limiting executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) + + val taskID0 = TaskID.newBuilder().setValue("0").build() + + backend.resourceOffers(driver, mesosOffers) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + // simulate the allocation manager down-scaling executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("s1/0"))) + verify(driver, times(1)).killTask(taskID0) + + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) + backend.resourceOffers(driver, mesosOffers2) + + verify(driver, times(1)) + .declineOffer(OfferID.newBuilder().setValue("o2").build()) + + // Verify we didn't launch any new executor + assert(backend.slaveIdsWithExecutors.size === 1) + + backend.doRequestTotalExecutors(2) + backend.resourceOffers(driver, mesosOffers2) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + assert(backend.slaveIdsWithExecutors.size === 2) + backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) + assert(backend.slaveIdsWithExecutors.size === 1) + } + + test("mesos supports killing and relaunching tasks with executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + 1024 + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + val offer1 = createOffer("o1", "s1", minMem, minCpu) + mesosOffers.add(offer1) + + val offer2 = createOffer("o2", "s1", minMem, 1); + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer1.getId)), + anyObject(), + anyObject[Filters]) + + // Simulate task killed, executor no longer running + val status = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue("0").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setState(TaskState.TASK_KILLED) + .build + + backend.statusUpdate(driver, status) + assert(!backend.slaveIdsWithExecutors.contains("s1")) + + mesosOffers.clear() + mesosOffers.add(offer2) + backend.resourceOffers(driver, mesosOffers) + assert(backend.slaveIdsWithExecutors.contains("s1")) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer2.getId)), + anyObject(), + anyObject[Filters]) + + verify(driver, times(1)).reviveOffers() + } +} From 1f6b0b1234cc03aa2e07aea7fec2de7563885238 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Jul 2015 13:48:29 -0700 Subject: [PATCH 229/274] [SPARK-8701] [STREAMING] [WEBUI] Add input metadata in the batch page This PR adds `metadata` to `InputInfo`. `InputDStream` can report its metadata for a batch and it will be shown in the batch page. For example, ![screen shot](https://cloud.githubusercontent.com/assets/1000778/8403741/d6ffc7e2-1e79-11e5-9888-c78c1575123a.png) FileInputDStream will display the new files for a batch, and DirectKafkaInputDStream will display its offset ranges. Author: zsxwing Closes #7081 from zsxwing/input-metadata and squashes the following commits: f7abd9b [zsxwing] Revert the space changes in project/MimaExcludes.scala d906209 [zsxwing] Merge branch 'master' into input-metadata 74762da [zsxwing] Fix MiMa tests 7903e33 [zsxwing] Merge branch 'master' into input-metadata 450a46c [zsxwing] Address comments 1d94582 [zsxwing] Raname InputInfo to StreamInputInfo and change "metadata" to Map[String, Any] d496ae9 [zsxwing] Add input metadata in the batch page --- .../kafka/DirectKafkaInputDStream.scala | 23 ++++++++-- .../spark/streaming/kafka/OffsetRange.scala | 2 +- project/MimaExcludes.scala | 6 +++ .../streaming/dstream/FileInputDStream.scala | 10 ++++- .../dstream/ReceiverInputDStream.scala | 4 +- .../spark/streaming/scheduler/BatchInfo.scala | 9 ++-- .../scheduler/InputInfoTracker.scala | 38 +++++++++++++--- .../streaming/scheduler/JobGenerator.scala | 3 +- .../spark/streaming/scheduler/JobSet.scala | 4 +- .../apache/spark/streaming/ui/BatchPage.scala | 43 +++++++++++++++++-- .../spark/streaming/ui/BatchUIData.scala | 8 ++-- .../ui/StreamingJobProgressListener.scala | 5 ++- .../streaming/StreamingListenerSuite.scala | 6 +-- .../spark/streaming/TestSuiteBase.scala | 2 +- .../scheduler/InputInfoTrackerSuite.scala | 8 ++-- .../StreamingJobProgressListenerSuite.scala | 28 ++++++------ 16 files changed, 148 insertions(+), 51 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 876456c964770..48a1933d92f85 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka import scala.annotation.tailrec import scala.collection.mutable -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -119,8 +119,23 @@ class DirectKafkaInputDStream[ val rdd = KafkaRDD[K, V, U, T, R]( context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) - // Report the record number of this batch interval to InputInfoTracker. - val inputInfo = InputInfo(id, rdd.count) + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. + offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 2675042666304..f326e7f1f6f8d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -75,7 +75,7 @@ final class OffsetRange private( } override def toString(): String = { - s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" } /** this is to avoid ClassNotFoundException during checkpoint restore */ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 821aadd477ef3..79089aae2a37c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -77,6 +77,12 @@ object MimaExcludes { // SPARK-8914 Remove RDDApi ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-8701 Add input metadata in the batch page. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo") ) case v if v.startsWith("1.4") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 86a8e2beff57c..dd4da9d9ca6a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** @@ -144,7 +145,14 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) batchTimeToSelectedFiles += ((validTime, newFiles)) recentlySelectedFiles ++= newFiles - Some(filesToRDD(newFiles)) + val rdds = Some(filesToRDD(newFiles)) + // Copy newFiles to immutable.List to prevent from being modified by the user + val metadata = Map( + "files" -> newFiles.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> newFiles.mkString("\n")) + val inputInfo = StreamInputInfo(id, 0, metadata) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + rdds } /** Clear the old time-to-files mappings along with old RDDs */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index e76e7eb0dea19..a50f0efc030ce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.storage.BlockId import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -70,7 +70,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker - val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum) + val inputInfo = StreamInputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) if (blockInfos.nonEmpty) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 5b9bfbf9b01e3..9922b6bc1201b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Time * :: DeveloperApi :: * Class having information on completed batches. * @param batchTime Time of the batch - * @param streamIdToNumRecords A map of input stream id to record number + * @param streamIdToInputInfo A map of input stream id to its input info * @param submissionTime Clock time of when jobs of this batch was submitted to * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing @@ -33,12 +33,15 @@ import org.apache.spark.streaming.Time @DeveloperApi case class BatchInfo( batchTime: Time, - streamIdToNumRecords: Map[Int, Long], + streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] ) { + @deprecated("Use streamIdToInputInfo instead", "1.5.0") + def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) + /** * Time taken for the first job of this batch to start processing from the time this batch * was submitted to the streaming scheduler. Essentially, it is @@ -63,5 +66,5 @@ case class BatchInfo( /** * The number of recorders received by the receivers in this batch. */ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 7c0db8a863c67..363c03d431f04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -20,11 +20,34 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.streaming.{Time, StreamingContext} -/** To track the information of input stream at specified batch time. */ -private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { +/** + * :: DeveloperApi :: + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + */ +@DeveloperApi +case class StreamInputInfo( + inputStreamId: Int, numRecords: Long, metadata: Map[String, Any] = Map.empty) { require(numRecords >= 0, "numRecords must not be negative") + + def metadataDescription: Option[String] = + metadata.get(StreamInputInfo.METADATA_KEY_DESCRIPTION).map(_.toString) +} + +@DeveloperApi +object StreamInputInfo { + + /** + * The key for description in `StreamInputInfo.metadata`. + */ + val METADATA_KEY_DESCRIPTION: String = "Description" } /** @@ -34,12 +57,13 @@ private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging { // Map to track all the InputInfo related to specific batch time and input stream. - private val batchTimeToInputInfos = new mutable.HashMap[Time, mutable.HashMap[Int, InputInfo]] + private val batchTimeToInputInfos = + new mutable.HashMap[Time, mutable.HashMap[Int, StreamInputInfo]] /** Report the input information with batch time to the tracker */ - def reportInfo(batchTime: Time, inputInfo: InputInfo): Unit = synchronized { + def reportInfo(batchTime: Time, inputInfo: StreamInputInfo): Unit = synchronized { val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime, - new mutable.HashMap[Int, InputInfo]()) + new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + @@ -49,10 +73,10 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging } /** Get the all the input stream's information of specified batch time */ - def getInfo(batchTime: Time): Map[Int, InputInfo] = synchronized { + def getInfo(batchTime: Time): Map[Int, StreamInputInfo] = synchronized { val inputInfos = batchTimeToInputInfos.get(batchTime) // Convert mutable HashMap to immutable Map for the caller - inputInfos.map(_.toMap).getOrElse(Map[Int, InputInfo]()) + inputInfos.map(_.toMap).getOrElse(Map[Int, StreamInputInfo]()) } /** Cleanup the tracked input information older than threshold batch time */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 9f93d6cbc3c20..f5d41858646e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -244,8 +244,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } match { case Success(jobs) => val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time) - val streamIdToNumRecords = streamIdToInputInfos.mapValues(_.numRecords) - jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToNumRecords)) + jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index e6be63b2ddbdc..95833efc9417f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,7 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - streamIdToNumRecords: Map[Int, Long] = Map.empty) { + streamIdToInputInfo: Map[Int, StreamInputInfo] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -64,7 +64,7 @@ case class JobSet( def toBatchInfo: BatchInfo = { new BatchInfo( time, - streamIdToNumRecords, + streamIdToInputInfo, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index f75067669abe5..0c891662c264f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{NodeSeq, Node, Text} +import scala.xml.{NodeSeq, Node, Text, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -303,6 +301,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { batchUIData.processingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val formattedTotalDelay = batchUIData.totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + val inputMetadatas = batchUIData.streamIdToInputInfo.values.flatMap { inputInfo => + inputInfo.metadataDescription.map(desc => inputInfo.inputStreamId -> desc) + }.toSeq val summary: NodeSeq =

    @@ -326,6 +327,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { Total delay: {formattedTotalDelay} + { + if (inputMetadatas.nonEmpty) { +
  • + Input Metadata:{generateInputMetadataTable(inputMetadatas)} +
  • + } + }
@@ -340,4 +348,33 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) } + + def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { + + + + + + + + + {inputMetadatas.flatMap(generateInputMetadataRow)} + +
InputMetadata
+ } + + def generateInputMetadataRow(inputMetadata: (Int, String)): Seq[Node] = { + val streamId = inputMetadata._1 + + + {streamingListener.streamName(streamId).getOrElse(s"Stream-$streamId")} + {metadataDescriptionToHTML(inputMetadata._2)} + + } + + private def metadataDescriptionToHTML(metadataDescription: String): Seq[Node] = { + // tab to 4 spaces and "\n" to "
" + Unparsed(StringEscapeUtils.escapeHtml4(metadataDescription). + replaceAllLiterally("\t", "    ").replaceAllLiterally("\n", "
")) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index a5514dfd71c9f..ae508c0e9577b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -19,14 +19,14 @@ package org.apache.spark.streaming.ui import org.apache.spark.streaming.Time -import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} import org.apache.spark.streaming.ui.StreamingJobProgressListener._ private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) private[ui] case class BatchUIData( val batchTime: Time, - val streamIdToNumRecords: Map[Int, Long], + val streamIdToInputInfo: Map[Int, StreamInputInfo], val submissionTime: Long, val processingStartTime: Option[Long], val processingEndTime: Option[Long], @@ -58,7 +58,7 @@ private[ui] case class BatchUIData( /** * The number of recorders received by the receivers in this batch. */ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } private[ui] object BatchUIData { @@ -66,7 +66,7 @@ private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { new BatchUIData( batchInfo.batchTime, - batchInfo.streamIdToNumRecords, + batchInfo.streamIdToInputInfo, batchInfo.submissionTime, batchInfo.processingStartTime, batchInfo.processingEndTime diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 68e8ce98945e0..b77c555c68b8b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -192,7 +192,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { val _retainedBatches = retainedBatches val latestBatches = _retainedBatches.map { batchUIData => - (batchUIData.batchTime.milliseconds, batchUIData.streamIdToNumRecords) + (batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords)) } streamIds.map { streamId => val eventRates = latestBatches.map { @@ -205,7 +205,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def lastReceivedBatchRecords: Map[Int, Long] = synchronized { - val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.streamIdToNumRecords) + val lastReceivedBlockInfoOption = + lastReceivedBatch.map(_.streamIdToInputInfo.mapValues(_.numRecords)) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => streamIds.map { streamId => (streamId, lastReceivedBlockInfo.getOrElse(streamId, 0L)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 7bc7727a9fbe4..4bc1dd4a30fc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -59,7 +59,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosSubmitted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) @@ -77,7 +77,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosStarted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) @@ -98,7 +98,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosCompleted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 31b1aebf6a8ec..0d58a7b54412f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -76,7 +76,7 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], } // Report the input data's information to InputInfoTracker for testing - val inputInfo = InputInfo(id, selectedInput.length.toLong) + val inputInfo = StreamInputInfo(id, selectedInput.length.toLong) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 2e210397fe7c7..f5248acf712b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -46,8 +46,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { val streamId1 = 0 val streamId2 = 1 val time = Time(0L) - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId2, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId2, 300L) inputInfoTracker.reportInfo(time, inputInfo1) inputInfoTracker.reportInfo(time, inputInfo2) @@ -63,8 +63,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { val inputInfoTracker = new InputInfoTracker(ssc) val streamId1 = 0 - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId1, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId1, 300L) inputInfoTracker.reportInfo(Time(0), inputInfo1) inputInfoTracker.reportInfo(Time(1), inputInfo2) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index c9175d61b1f49..40dc1fb601bd0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -49,10 +49,12 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -64,7 +66,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -94,7 +96,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoStarted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoStarted.processingDelay) batchUIData.get.totalDelay should be (batchInfoStarted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map(0 -> 300L, 1 -> 300L)) + batchUIData.get.streamIdToInputInfo should be (Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test")))) batchUIData.get.numRecords should be(600) batchUIData.get.outputOpIdSparkJobIdPairs should be Seq(OutputOpIdAndSparkJobId(0, 0), @@ -103,7 +107,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -141,9 +145,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -182,7 +186,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoSubmitted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoSubmitted.processingDelay) batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map.empty) + batchUIData.get.streamIdToInputInfo should be (Map.empty) batchUIData.get.numRecords should be (0) batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) @@ -211,14 +215,14 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) for (_ <- 0 until 2 * limit) { - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -235,7 +239,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } From 3ccebf36c5abe04702d4cf223552a94034d980fb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 9 Jul 2015 13:54:44 -0700 Subject: [PATCH 230/274] [SPARK-8389] [STREAMING] [PYSPARK] Expose KafkaRDDs offsetRange in Python This PR propose a simple way to expose OffsetRange in Python code, also the usage of offsetRanges is similar to Scala/Java way, here in Python we could get OffsetRange like: ``` dstream.foreachRDD(lambda r: KafkaUtils.offsetRanges(r)) ``` Reason I didn't follow the way what SPARK-8389 suggested is that: Python Kafka API has one more step to decode the message compared to Scala/Java, Which makes Python API return a transformed RDD/DStream, not directly wrapped so-called JavaKafkaRDD, so it is hard to backtrack to the original RDD to get the offsetRange. Author: jerryshao Closes #7185 from jerryshao/SPARK-8389 and squashes the following commits: 4c6d320 [jerryshao] Another way to fix subclass deserialization issue e6a8011 [jerryshao] Address the comments fd13937 [jerryshao] Fix serialization bug 7debf1c [jerryshao] bug fix cff3893 [jerryshao] refactor the code according to the comments 2aabf9e [jerryshao] Style fix 848c708 [jerryshao] Add HasOffsetRanges for Python --- .../spark/streaming/kafka/KafkaUtils.scala | 13 ++ python/pyspark/streaming/kafka.py | 123 ++++++++++++++++-- python/pyspark/streaming/tests.py | 64 +++++++++ python/pyspark/streaming/util.py | 7 +- 4 files changed, 196 insertions(+), 11 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0e33362d34acd..f3b01bd60b178 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -670,4 +670,17 @@ private class KafkaUtilsPythonHelper { TopicAndPartition(topic, partition) def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq + } } diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 10a859a532e28..33dd596335b47 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -21,6 +21,8 @@ from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream +from pyspark.streaming.dstream import TransformedDStream +from pyspark.streaming.util import TransformFunction __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] @@ -122,8 +124,9 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + stream = DStream(jstream, ssc, ser) \ + .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod def createRDD(sc, kafkaParams, offsetRanges, leaders={}, @@ -161,8 +164,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - rdd = RDD(jrdd, sc, ser) - return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer) @staticmethod def _printErrorMsg(sc): @@ -200,14 +203,30 @@ def __init__(self, topic, partition, fromOffset, untilOffset): :param fromOffset: Inclusive starting offset. :param untilOffset: Exclusive ending offset. """ - self._topic = topic - self._partition = partition - self._fromOffset = fromOffset - self._untilOffset = untilOffset + self.topic = topic + self.partition = partition + self.fromOffset = fromOffset + self.untilOffset = untilOffset + + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self.topic == other.topic + and self.partition == other.partition + and self.fromOffset == other.fromOffset + and self.untilOffset == other.untilOffset) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \ + % (self.topic, self.partition, self.fromOffset, self.untilOffset) def _jOffsetRange(self, helper): - return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, - self._untilOffset) + return helper.createOffsetRange(self.topic, self.partition, self.fromOffset, + self.untilOffset) class TopicAndPartition(object): @@ -244,3 +263,87 @@ def __init__(self, host, port): def _jBroker(self, helper): return helper.createBroker(self._host, self._port) + + +class KafkaRDD(RDD): + """ + A Python wrapper of KafkaRDD, to provide additional information on normal RDD. + """ + + def __init__(self, jrdd, ctx, jrdd_deserializer): + RDD.__init__(self, jrdd, ctx, jrdd_deserializer) + + def offsetRanges(self): + """ + Get the OffsetRange of specific KafkaRDD. + :return: A list of OffsetRange + """ + try: + helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd()) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(self.ctx) + raise e + + ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset()) + for o in joffsetRanges] + return ranges + + +class KafkaDStream(DStream): + """ + A Python wrapper of KafkaDStream + """ + + def __init__(self, jdstream, ssc, jrdd_deserializer): + DStream.__init__(self, jdstream, ssc, jrdd_deserializer) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.__code__.co_argcount == 1: + old_func = func + func = lambda r, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.__code__.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.__code__.co_argcount == 2, "func should take one or two arguments" + + return KafkaTransformedDStream(self, func) + + +class KafkaTransformedDStream(TransformedDStream): + """ + Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. + """ + + def __init__(self, prev, func): + TransformedDStream.__init__(self, prev, func) + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 188c8ff12067e..4ecae1e4bf282 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -678,6 +678,70 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_get_offsetRanges(self): + """Test Python direct Kafka RDD get OffsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 3, "b": 4, "c": 5} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self.assertEqual(offsetRanges, rdd.offsetRanges()) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_foreach_get_offsetRanges(self): + """Test the Python direct Kafka stream foreachRDD get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def getOffsetRanges(_, rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + + stream.foreachRDD(getOffsetRanges) + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_get_offsetRanges(self): + """Test the Python direct Kafka stream transform get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index a9bfec2aab8fc..b20613b1283bd 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,6 +37,11 @@ def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers + self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + + def rdd_wrapper(self, func): + self._rdd_wrapper = func + return self def call(self, milliseconds, jrdds): try: @@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) From c9e2ef52bb54f35a904427389dc492d61f29b018 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 14:43:38 -0700 Subject: [PATCH 231/274] [SPARK-7902] [SPARK-6289] [SPARK-8685] [SQL] [PYSPARK] Refactor of serialization for Python DataFrame This PR fix the long standing issue of serialization between Python RDD and DataFrame, it change to using a customized Pickler for InternalRow to enable customized unpickling (type conversion, especially for UDT), now we can support UDT for UDF, cc mengxr . There is no generated `Row` anymore. Author: Davies Liu Closes #7301 from davies/sql_ser and squashes the following commits: 81bef71 [Davies Liu] address comments e9217bd [Davies Liu] add regression tests db34167 [Davies Liu] Refactor of serialization for Python DataFrame --- python/pyspark/sql/context.py | 5 +- python/pyspark/sql/dataframe.py | 16 +- python/pyspark/sql/tests.py | 28 +- python/pyspark/sql/types.py | 419 ++++++------------ .../spark/sql/catalyst/expressions/rows.scala | 12 + .../org/apache/spark/sql/DataFrame.scala | 5 +- .../spark/sql/execution/pythonUDFs.scala | 122 ++++- 7 files changed, 292 insertions(+), 315 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 309c11faf9319..c93a15badae29 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -30,7 +30,7 @@ from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ - _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter + _infer_schema, _has_nulltype, _merge_type, _create_converter from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.utils import install_exception_handler @@ -388,8 +388,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): raise TypeError("schema should be StructType or list or None") # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) + rdd = rdd.map(schema.toInternal) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e9c657cf81b3..83e02b85f06f1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -31,7 +31,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql import since -from pyspark.sql.types import _create_cls, _parse_datatype_json_string +from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * @@ -83,15 +83,7 @@ def rdd(self): """ if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() - rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema - - def applySchema(it): - cls = _create_cls(schema) - return map(cls, it) - - self._lazy_rdd = rdd.mapPartitions(applySchema) - + self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) return self._lazy_rdd @property @@ -287,9 +279,7 @@ def collect(self): """ with SCCallSiteSync(self._sc) as css: port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) - rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) - cls = _create_cls(self.schema) - return [cls(r) for r in rs] + return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(1.3) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 66827d48850d9..4d7cad5a1ab88 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -151,6 +151,17 @@ def test_range(self): self.assertEqual(self.sqlCtx.range(-2).count(), 0) self.assertEqual(self.sqlCtx.range(3).count(), 3) + def test_duplicated_column_names(self): + df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"]) + row = df.select('*').first() + self.assertEqual(1, row[0]) + self.assertEqual(2, row[1]) + self.assertEqual("Row(c=1, c=2)", str(row)) + # Cannot access columns + self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) + self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) + self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) + def test_explode(self): from pyspark.sql.functions import explode d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] @@ -401,6 +412,14 @@ def test_apply_schema_with_udt(self): point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_udf_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sc.parallelize([row]).toDF() + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) @@ -693,12 +712,9 @@ def test_time_with_timezone(self): utcnow = datetime.datetime.fromtimestamp(ts, utc) df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) day1, now1, utcnow1 = df.first() - # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version - self.assertEqual(day1.date(), day) - # Pyrolite does not support microsecond, the error should be - # less than 1 millisecond - self.assertTrue(now - now1 < datetime.timedelta(0.001)) - self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + self.assertEqual(day1, day) + self.assertEqual(now, now1) + self.assertEqual(now, utcnow1) def test_decimal(self): from decimal import Decimal diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fecfe6d71e9a7..d63857691675a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -20,13 +20,9 @@ import time import datetime import calendar -import keyword -import warnings import json import re -import weakref from array import array -from operator import itemgetter if sys.version >= "3": long = int @@ -71,6 +67,26 @@ def json(self): separators=(',', ':'), sort_keys=True) + def needConversion(self): + """ + Does this type need to conversion between Python object and internal SQL object. + + This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. + """ + return False + + def toInternal(self, obj): + """ + Converts a Python object into an internal SQL object. + """ + return obj + + def fromInternal(self, obj): + """ + Converts an internal SQL object into a native Python object. + """ + return obj + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -143,6 +159,17 @@ class DateType(AtomicType): __metaclass__ = DataTypeSingleton + EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def needConversion(self): + return True + + def toInternal(self, d): + return d and d.toordinal() - self.EPOCH_ORDINAL + + def fromInternal(self, v): + return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. @@ -150,6 +177,19 @@ class TimestampType(AtomicType): __metaclass__ = DataTypeSingleton + def needConversion(self): + return True + + def toInternal(self, dt): + if dt is not None: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e6 + dt.microsecond) + + def fromInternal(self, ts): + if ts is not None: + return datetime.datetime.fromtimestamp(ts / 1e6) + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -259,6 +299,19 @@ def fromJson(cls, json): return ArrayType(_parse_datatype_json_value(json["elementType"]), json["containsNull"]) + def needConversion(self): + return self.elementType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.toInternal(v) for v in obj] + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.fromInternal(v) for v in obj] + class MapType(DataType): """Map data type. @@ -304,6 +357,21 @@ def fromJson(cls, json): _parse_datatype_json_value(json["valueType"]), json["valueContainsNull"]) + def needConversion(self): + return self.keyType.needConversion() or self.valueType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) + for k, v in obj.items()) + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) + for k, v in obj.items()) + class StructField(DataType): """A field in :class:`StructType`. @@ -311,7 +379,7 @@ class StructField(DataType): :param name: string, name of the field. :param dataType: :class:`DataType` of the field. :param nullable: boolean, whether the field can be null (None) or not. - :param metadata: a dict from string to simple type that can be serialized to JSON automatically + :param metadata: a dict from string to simple type that can be toInternald to JSON automatically """ def __init__(self, name, dataType, nullable=True, metadata=None): @@ -351,6 +419,15 @@ def fromJson(cls, json): json["nullable"], json["metadata"]) + def needConversion(self): + return self.dataType.needConversion() + + def toInternal(self, obj): + return self.dataType.toInternal(obj) + + def fromInternal(self, obj): + return self.dataType.fromInternal(obj) + class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. @@ -371,10 +448,13 @@ def __init__(self, fields=None): """ if not fields: self.fields = [] + self.names = [] else: self.fields = fields + self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" + self._needSerializeFields = None def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -406,6 +486,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ if isinstance(field, StructField): self.fields.append(field) + self.names.append(field.name) else: if isinstance(field, str) and data_type is None: raise ValueError("Must specify DataType if passing name of struct_field to create.") @@ -415,6 +496,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) + self.names.append(field) return self def simpleString(self): @@ -432,6 +514,41 @@ def jsonValue(self): def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) + def needConversion(self): + # We need convert Row()/namedtuple into tuple() + return True + + def toInternal(self, obj): + if obj is None: + return + + if self._needSerializeFields is None: + self._needSerializeFields = any(f.needConversion() for f in self.fields) + + if self._needSerializeFields: + if isinstance(obj, dict): + return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + elif isinstance(obj, (tuple, list)): + return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + else: + if isinstance(obj, dict): + return tuple(obj.get(n) for n in self.names) + elif isinstance(obj, (list, tuple)): + return tuple(obj) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + + def fromInternal(self, obj): + if obj is None: + return + if isinstance(obj, Row): + # it's already converted by pickler + return obj + values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + return _create_row(self.names, values) + class UserDefinedType(DataType): """User-defined type (UDT). @@ -464,17 +581,35 @@ def scalaUDT(cls): """ raise NotImplementedError("UDT must have a paired Scala UDT.") + def needConversion(self): + return True + + @classmethod + def _cachedSqlType(cls): + """ + Cache the sqlType() into class, because it's heavy used in `toInternal`. + """ + if not hasattr(cls, "_cached_sql_type"): + cls._cached_sql_type = cls.sqlType() + return cls._cached_sql_type + + def toInternal(self, obj): + return self._cachedSqlType().toInternal(self.serialize(obj)) + + def fromInternal(self, obj): + return self.deserialize(self._cachedSqlType().fromInternal(obj)) + def serialize(self, obj): """ Converts the a user-type object into a SQL datum. """ - raise NotImplementedError("UDT must implement serialize().") + raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum): """ Converts a SQL datum into a user-type object. """ - raise NotImplementedError("UDT must implement deserialize().") + raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self): return 'udt' @@ -671,117 +806,6 @@ def _infer_schema(row): return StructType(fields) -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - True - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - # convert namedtuple or Row into tuple - return True - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - elif isinstance(dataType, (DateType, TimestampType)): - return True - else: - return False - - -EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - if any(_need_python_to_sql_conversion(t) for t in types): - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - elif obj is not None: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - else: - def converter(obj): - if isinstance(obj, dict): - return tuple(obj.get(n) for n in names) - else: - return tuple(obj) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: a and [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - - elif isinstance(dataType, UserDefinedType): - return lambda obj: obj and dataType.serialize(obj) - - elif isinstance(dataType, DateType): - return lambda d: d and d.toordinal() - EPOCH_ORDINAL - - elif isinstance(dataType, TimestampType): - - def to_posix_timstamp(dt): - if dt: - seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo - else time.mktime(dt.timetuple())) - return int(seconds * 1e6 + dt.microsecond) - return to_posix_timstamp - - else: - raise ValueError("Unexpected type %r" % dataType) - - def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): @@ -1076,7 +1100,7 @@ def _verify_type(obj, dataType): if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) + _verify_type(dataType.toInternal(obj), dataType.sqlType()) return _type = type(dataType) @@ -1086,7 +1110,7 @@ def _verify_type(obj, dataType): if not isinstance(obj, (tuple, list)): raise TypeError("StructType can not accept object in type %s" % type(obj)) else: - # subclass of them can not be deserialized in JVM + # subclass of them can not be fromInternald in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) @@ -1106,159 +1130,10 @@ def _verify_type(obj, dataType): for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) -_cached_cls = weakref.WeakValueDictionary() - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None or cls.__datatype is not dataType: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - cls.__datatype = dataType - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for atomic types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __datatype = dataType - __fields__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__fields__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__fields__)) - - def __reduce__(self): - return (_restore_object, (self.__datatype, tuple(self))) - return Row +# This is used to unpickle a Row from JVM +def _create_row_inbound_converter(dataType): + return lambda *a: dataType.fromInternal(a) def _create_row(fields, values): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8b472a529e5c9..094904bbf9c15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -132,6 +132,18 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) override def copy(): InternalRow = this } +/** + * This is used for serialization of Python DataFrame + */ +class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType) + extends GenericInternalRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) + + override def fieldIndex(name: String): Int = schema.fieldIndex(name) +} + class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { /** No-arg constructor for serialization. */ protected def this() = this(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d9f987ae0252f..d7966651b1948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -30,7 +30,6 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -1550,8 +1549,8 @@ class DataFrame private[sql]( */ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { val structType = schema // capture it for closure - val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) + val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + EvaluatePython.javaToPython(rdd) } //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 1c8130b07c7fb..6d6e67dace177 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.execution +import java.io.OutputStream import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import net.razorvine.pickle.{Pickler, Unpickler} +import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} +import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Accumulator, Logging => SparkLogging} @@ -130,8 +130,13 @@ object EvaluatePython { case (null, _) => null case (row: InternalRow, struct: StructType) => - val fields = struct.fields.map(field => field.dataType) - rowToArray(row, fields) + val values = new Array[Any](row.size) + var i = 0 + while (i < row.size) { + values(i) = toJava(row(i), struct.fields(i).dataType) + i += 1 + } + new GenericInternalRowWithSchema(values, struct) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava @@ -142,9 +147,6 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) - case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) - case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) - case (d: Decimal, _) => d.toJavaBigDecimal case (s: UTF8String, StringType) => s.toString @@ -152,14 +154,6 @@ object EvaluatePython { case (other, _) => other } - /** - * Convert Row into Java Array (for pickled into Python) - */ - def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = { - // TODO: this is slow! - row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray - } - /** * Converts `obj` to the type specified by the data type, or returns null if the type of obj is * unexpected. Because Python doesn't enforce the type. @@ -220,6 +214,96 @@ object EvaluatePython { // TODO(davies): we could improve this by try to cast the object to expected type case (c, _) => null } + + + private val module = "pyspark.sql.types" + + /** + * Pickler for StructType + */ + private class StructTypePickler extends IObjectPickler { + + private val cls = classOf[StructType] + + def register(): Unit = { + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + out.write(Opcodes.GLOBAL) + out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8")) + val schema = obj.asInstanceOf[StructType] + pickler.save(schema.json) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + } + } + + /** + * Pickler for InternalRow + */ + private class RowPickler extends IObjectPickler { + + private val cls = classOf[GenericInternalRowWithSchema] + + // register this to Pickler and Unpickler + def register(): Unit = { + Pickler.registerCustomPickler(this.getClass, this) + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8")) + } else { + // it will be memorized by Pickler to save some bytes + pickler.save(this) + val row = obj.asInstanceOf[GenericInternalRowWithSchema] + // schema should always be same object for memoization + pickler.save(row.schema) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + + out.write(Opcodes.MARK) + var i = 0 + while (i < row.values.size) { + pickler.save(row.values(i)) + i += 1 + } + row.values.foreach(pickler.save) + out.write(Opcodes.TUPLE) + out.write(Opcodes.REDUCE) + } + } + } + + private[this] var registered = false + /** + * This should be called before trying to serialize any above classes un cluster mode, + * this should be put in the closure + */ + def registerPicklers(): Unit = { + synchronized { + if (!registered) { + SerDeUtil.initialize() + new StructTypePickler().register() + new RowPickler().register() + registered = true + } + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { + rdd.mapPartitions { iter => + registerPicklers() // let it called in executor + new SerDeUtil.AutoBatchedPickler(iter) + } + } } /** @@ -254,12 +338,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val childResults = child.execute().map(_.copy()) val parent = childResults.mapPartitions { iter => + EvaluatePython.registerPicklers() // register pickler for Row val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) - iter.grouped(1000).map { inputRows => + val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) + iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => - EvaluatePython.rowToArray(currentRow(row), fields) + EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } From 897700369f3aedf1a8fdb0984dd3d6d8e498e3af Mon Sep 17 00:00:00 2001 From: guowei2 Date: Thu, 9 Jul 2015 15:01:53 -0700 Subject: [PATCH 232/274] [SPARK-8865] [STREAMING] FIX BUG: check key in kafka params Author: guowei2 Closes #7254 from guowei2/spark-8865 and squashes the following commits: 48ca17a [guowei2] fix contains key --- .../scala/org/apache/spark/streaming/kafka/KafkaCluster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 3e6b937af57b0..8465432c5850f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -410,7 +410,7 @@ object KafkaCluster { } Seq("zookeeper.connect", "group.id").foreach { s => - if (!props.contains(s)) { + if (!props.containsKey(s)) { props.setProperty(s, "") } } From 69165330303a71ea1da748eca7a780ec172b326f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 9 Jul 2015 15:14:14 -0700 Subject: [PATCH 233/274] Closes #6837 Closes #7321 Closes #2634 Closes #4963 Closes #2137 From e29ce319fa6ffb9c8e5110814d4923d433aa1b76 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 9 Jul 2015 15:49:30 -0700 Subject: [PATCH 234/274] [SPARK-8963][ML] cleanup tests in linear regression suite Simplify model weight assertions to use vector comparision, switch to using absTol when comparing with 0.0 intercepts Author: Holden Karau Closes #7327 from holdenk/SPARK-8913-cleanup-tests-from-SPARK-8700-logistic-regression and squashes the following commits: 5bac185 [Holden Karau] Simplify model weight assertions to use vector comparision, switch to using absTol when comparing with 0.0 intercepts --- .../ml/regression/LinearRegressionSuite.scala | 57 ++++++++----------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 5f39d44f37352..4f6a57739558b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -75,11 +75,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 - val weightsR = Array(4.700706, 7.199082) + val weightsR = Vectors.dense(4.700706, 7.199082) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -104,11 +103,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V2. 6.995908 as.numeric.data.V3. 5.275131 */ - val weightsR = Array(6.995908, 5.275131) + val weightsR = Vectors.dense(6.995908, 5.275131) - assert(model.intercept ~== 0 relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== 0 absTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) /* Then again with the data with no intercept: > weightsWithoutIntercept @@ -118,11 +116,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data3.V2. 4.70011 as.numeric.data3.V3. 7.19943 */ - val weightsWithoutInterceptR = Array(4.70011, 7.19943) + val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) - assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3) - assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3) - assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3) + assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3) } test("linear regression with intercept with L1 regularization") { @@ -139,11 +136,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 6.679841 */ val interceptR = 6.24300 - val weightsR = Array(4.024821, 6.679841) + val weightsR = Vectors.dense(4.024821, 6.679841) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -169,11 +165,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.772913 */ val interceptR = 0.0 - val weightsR = Array(6.299752, 4.772913) + val weightsR = Vectors.dense(6.299752, 4.772913) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-5) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -197,11 +192,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.926260 */ val interceptR = 5.269376 - val weightsR = Array(3.736216, 5.712356) + val weightsR = Vectors.dense(3.736216, 5.712356) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -227,11 +221,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.214502 */ val interceptR = 0.0 - val weightsR = Array(5.522875, 4.214502) + val weightsR = Vectors.dense(5.522875, 4.214502) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + assert(model.weights ~== weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -255,11 +248,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 - val weightsR = Array(3.670489, 6.001122) + val weightsR = Vectors.dense(3.670489, 6.001122) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~== weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -285,11 +277,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.dataM.V3. 4.322251 */ val interceptR = 0.0 - val weightsR = Array(5.673348, 4.322251) + val weightsR = Vectors.dense(5.673348, 4.322251) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => From a0cc3e5aa3fcfd0fce6813c520152657d327aaf2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 9 Jul 2015 16:21:21 -0700 Subject: [PATCH 235/274] [SPARK-8538] [SPARK-8539] [ML] Linear Regression Training and Testing Results Adds results (e.g. objective value at each iteration, residuals) on training and user-specified test sets for LinearRegressionModel. Notes to Reviewers: * Are the `*TrainingResults` and `Results` classes too specialized for `LinearRegressionModel`? Where would be an appropriate level of abstraction? * Please check `transient` annotations are correct; the datasets should not be copied and kept during serialization. * Any thoughts on `RDD`s versus `DataFrame`s? If using `DataFrame`s, suggested schemas for each intermediate step? Also, how to create a "local DataFrame" without a `sqlContext`? Author: Feynman Liang Closes #7099 from feynmanliang/SPARK-8538 and squashes the following commits: d219fa4 [Feynman Liang] Update docs 4a42680 [Feynman Liang] Change Summary to hold values, move transient annotations down to metrics and predictions DF 6300031 [Feynman Liang] Code review changes 0a5e762 [Feynman Liang] Fix build error e71102d [Feynman Liang] Merge branch 'master' into SPARK-8538 3367489 [Feynman Liang] Merge branch 'master' into SPARK-8538 70f267c [Feynman Liang] Make TrainingSummary transient and remove Serializable from *Summary and RegressionMetrics 1d9ea42 [Feynman Liang] Fix failing Java test a65dfda [Feynman Liang] Make TrainingSummary and metrics serializable, prediction dataframe transient 0a605d8 [Feynman Liang] Replace Params from LinearRegression*Summary with private constructor vals c2fe835 [Feynman Liang] Optimize imports 02d8a70 [Feynman Liang] Add Params to LinearModel*Summary, refactor tests and add test for evaluate() 8f999f4 [Feynman Liang] Refactor from jkbradley code review 072e948 [Feynman Liang] Style 509ae36 [Feynman Liang] Use DFs and localize serialization to LinearRegressionModel 9509c79 [Feynman Liang] Fix imports b2bbaa3 [Feynman Liang] Refactored LinearRegressionResults API to be more private ffceaec [Feynman Liang] Merge branch 'master' into SPARK-8538 1cedb2b [Feynman Liang] Add test for decreasing objective trace dab0aff [Feynman Liang] Add LinearRegressionTrainingResults tests, make test suite code copy+pasteable 97b0a81 [Feynman Liang] Add LinearRegressionModel.evaluate() to get results on test sets dc51bce [Feynman Liang] Style guide fixes 521f397 [Feynman Liang] Use RDD[(Double, Double)] instead of DF 2ff5710 [Feynman Liang] Add training results and model summary to ML LinearRegression --- .../ml/regression/LinearRegression.scala | 139 +++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 59 ++++++++ 2 files changed, 192 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f672c96576a33..8fc986056657d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,18 +22,20 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV, norm => brzNorm} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -139,7 +141,16 @@ class LinearRegression(override val uid: String) logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + s"and the intercept will be the mean of the label; as a result, training is not needed.") if (handlePersistence) instances.unpersist() - return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean) + val weights = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, weights, intercept) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset).select($(predictionCol), $(labelCol)), + $(predictionCol), + $(labelCol), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) } val featuresMean = summarizer.mean.toArray @@ -178,7 +189,6 @@ class LinearRegression(override val uid: String) state = states.next() arrayBuilder += state.adjustedValue } - if (state == null) { val msg = s"${optimizer.getClass.getName} failed." logError(msg) @@ -209,7 +219,13 @@ class LinearRegression(override val uid: String) if (handlePersistence) instances.unpersist() - copyValues(new LinearRegressionModel(uid, weights, intercept)) + val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset).select($(predictionCol), $(labelCol)), + $(predictionCol), + $(labelCol), + objectiveHistory) + model.setSummary(trainingSummary) } override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) @@ -227,13 +243,124 @@ class LinearRegressionModel private[ml] ( extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams { + private var trainingSummary: Option[LinearRegressionTrainingSummary] = None + + /** + * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + def summary: LinearRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LinearRegressionModel", + new NullPointerException()) + } + + private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + val t = udf { features: Vector => predict(features) } + val predictionAndObservations = dataset + .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + + new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) + } + override protected def predict(features: Vector): Double = { dot(features, weights) + intercept } override def copy(extra: ParamMap): LinearRegressionModel = { - copyValues(new LinearRegressionModel(uid, weights, intercept), extra) + val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel + } +} + +/** + * :: Experimental :: + * Linear regression training results. + * @param predictions predictions outputted by the model's `transform` method. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +@Experimental +class LinearRegressionTrainingSummary private[regression] ( + predictions: DataFrame, + predictionCol: String, + labelCol: String, + val objectiveHistory: Array[Double]) + extends LinearRegressionSummary(predictions, predictionCol, labelCol) { + + /** Number of training iterations until termination */ + val totalIterations = objectiveHistory.length + +} + +/** + * :: Experimental :: + * Linear regression results evaluated on a dataset. + * @param predictions predictions outputted by the model's `transform` method. + */ +@Experimental +class LinearRegressionSummary private[regression] ( + @transient val predictions: DataFrame, + val predictionCol: String, + val labelCol: String) extends Serializable { + + @transient private val metrics = new RegressionMetrics( + predictions + .select(predictionCol, labelCol) + .map { case Row(pred: Double, label: Double) => (pred, label) } ) + + /** + * Returns the explained variance regression score. + * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + */ + val explainedVariance: Double = metrics.explainedVariance + + /** + * Returns the mean absolute error, which is a risk function corresponding to the + * expected value of the absolute error loss or l1-norm loss. + */ + val meanAbsoluteError: Double = metrics.meanAbsoluteError + + /** + * Returns the mean squared error, which is a risk function corresponding to the + * expected value of the squared error loss or quadratic loss. + */ + val meanSquaredError: Double = metrics.meanSquaredError + + /** + * Returns the root mean squared error, which is defined as the square root of + * the mean squared error. + */ + val rootMeanSquaredError: Double = metrics.rootMeanSquaredError + + /** + * Returns R^2^, the coefficient of determination. + * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + */ + val r2: Double = metrics.r2 + + /** Residuals (predicted value - label value) */ + @transient lazy val residuals: DataFrame = { + val t = udf { (pred: Double, label: Double) => pred - label} + predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) } + } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 4f6a57739558b..cf120cf2a4b47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -289,4 +289,63 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(prediction1 ~== prediction2 relTol 1E-5) } } + + test("linear regression model training summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Training results for the model should be available + assert(model.hasSummary) + + // Residuals in [[LinearRegressionResults]] should equal those manually computed + val expectedResiduals = dataset.select("features", "label") + .map { case Row(features: DenseVector, label: Double) => + val prediction = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + prediction - label + } + .zip(model.summary.residuals.map(_.getDouble(0))) + .collect() + .foreach { case (manualResidual: Double, resultResidual: Double) => + assert(manualResidual ~== resultResidual relTol 1E-5) + } + + /* + Use the following R code to generate model training results. + + predictions <- predict(fit, newx=features) + residuals <- predictions - label + > mean(residuals^2) # MSE + [1] 0.009720325 + > mean(abs(residuals)) # MAD + [1] 0.07863206 + > cor(predictions, label)^2# r^2 + [,1] + s0 0.9998749 + */ + assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) + assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) + assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + + // Objective function should be monotonically decreasing for linear regression + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } + + test("linear regression model testset evaluation summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Evaluating on training dataset should yield results summary equal to training summary + val testSummary = model.evaluate(dataset) + assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) + assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) + model.summary.residuals.select("residuals").collect() + .zip(testSummary.residuals.select("residuals").collect()) + .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + } + } From 2d45571fcb002cc9f03056c5a3f14493b83315a4 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 9 Jul 2015 17:09:16 -0700 Subject: [PATCH 236/274] [SPARK-8959] [SQL] [HOTFIX] Removes parquet-thrift and libthrift dependencies These two dependencies were introduced in #7231 to help testing Parquet compatibility with `parquet-thrift`. However, they somehow crash the Scala compiler in Maven builds. This PR fixes this issue by: 1. Removing these two dependencies, and 2. Instead of generating the testing Parquet file programmatically, checking in an actual testing Parquet file generated by `parquet-thrift` as a test resource. This is just a quick fix to bring back Maven builds. Need to figure out the root case as binary Parquet files are harder to maintain. Author: Cheng Lian Closes #7330 from liancheng/spark-8959 and squashes the following commits: cf69512 [Cheng Lian] Brings back Maven builds --- pom.xml | 14 - sql/core/pom.xml | 10 - .../spark/sql/parquet/test/thrift/Nested.java | 541 ---- .../test/thrift/ParquetThriftCompat.java | 2808 ----------------- .../spark/sql/parquet/test/thrift/Suit.java | 51 - .../parquet-thrift-compat.snappy.parquet | Bin 0 -> 10550 bytes .../ParquetThriftCompatibilitySuite.scala | 78 +- 7 files changed, 8 insertions(+), 3494 deletions(-) delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java create mode 100755 sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet diff --git a/pom.xml b/pom.xml index 529e47f8b5253..1eda108dc065b 100644 --- a/pom.xml +++ b/pom.xml @@ -161,7 +161,6 @@ 2.4.4 1.1.1.7 1.1.2 - 0.9.2 false @@ -181,7 +180,6 @@ compile compile test - test + + commons-codec + commons-codec + provided + + + commons-net + commons-net + provided + + + com.google.protobuf + protobuf-java + provided + org.apache.avro avro - ${avro.version} + provided org.apache.avro avro-ipc - ${avro.version} - - - io.netty - netty - - - org.mortbay.jetty - jetty - - - org.mortbay.jetty - jetty-util - - - org.mortbay.jetty - servlet-api - - - org.apache.velocity - velocity - - + provided + + + org.scala-lang + scala-library + provided - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + + + flume-provided + + provided + + + diff --git a/pom.xml b/pom.xml index 1eda108dc065b..172fdef4c73da 100644 --- a/pom.xml +++ b/pom.xml @@ -1130,6 +1130,10 @@ io.netty netty + + org.apache.flume + flume-ng-auth + org.apache.thrift libthrift From 2727304660663fcf1e41f7b666978c1443262e4e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 9 Jul 2015 19:08:33 -0700 Subject: [PATCH 238/274] [SPARK-8913] [ML] Simplify LogisticRegression suite to use Vector Vector comparision Cleanup tests from SPARK 8700. Author: Holden Karau Closes #7335 from holdenk/SPARK-8913-cleanup-tests-from-SPARK-8700-logistic-regression-r2-really-logistic-regression-this-time and squashes the following commits: e5e2c5f [Holden Karau] Simplify LogisticRegression suite to use Vector <-> Vector comparisions instead of comparing element by element --- .../LogisticRegressionSuite.scala | 135 +++++------------- 1 file changed, 39 insertions(+), 96 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 27253c1db2fff..b7dd44753896a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -234,20 +234,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7996864 */ val interceptR = 2.8366423 - val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model1.weights(1) ~== weightsR(1) relTol 1E-3) - assert(model1.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-3) // Without regularization, with or without standardization will converge to the same solution. assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model2.weights(1) ~== weightsR(1) relTol 1E-3) - assert(model2.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-3) } test("binary logistic regression without intercept without regularization") { @@ -277,20 +271,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7407946 */ val interceptR = 0.0 - val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights(0) ~== weightsR(0) relTol 1E-2) - assert(model1.weights(1) ~== weightsR(1) relTol 1E-2) - assert(model1.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-2) // Without regularization, with or without standardization should converge to the same solution. assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights(0) ~== weightsR(0) relTol 1E-2) - assert(model2.weights(1) ~== weightsR(1) relTol 1E-2) - assert(model2.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-2) } test("binary logistic regression with intercept with L1 regularization") { @@ -321,13 +309,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.02481551 */ val interceptR1 = -0.05627428 - val weightsR1 = Array(0.0, 0.0, -0.04325749, -0.02481551) + val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.weights(0) ~== weightsR1(0) absTol 1E-3) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-2) - assert(model1.weights(3) ~== weightsR1(3) relTol 2E-2) + assert(model1.weights ~= weightsR1 absTol 2E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -349,13 +334,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.3722152 - val weightsR2 = Array(0.0, 0.0, -0.1665453, 0.0) + val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2) - assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression without intercept with L1 regularization") { @@ -387,13 +369,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.03891782 */ val interceptR1 = 0.0 - val weightsR1 = Array(0.0, 0.0, -0.05189203, -0.03891782) + val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) absTol 1E-3) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-2) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-2) + assert(model1.weights ~= weightsR1 absTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -415,13 +394,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.0 - val weightsR2 = Array(0.0, 0.0, -0.08420782, 0.0) + val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2) - assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression with intercept with L2 regularization") { @@ -452,13 +428,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.10062872 */ val interceptR1 = 0.15021751 - val weightsR1 = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872) + val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) relTol 1E-3) - assert(model1.weights(1) ~== weightsR1(1) relTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -480,13 +453,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.06266838 */ val interceptR2 = 0.48657516 - val weightsR2 = Array(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) relTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) relTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR2(3) relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) } test("binary logistic regression without intercept with L2 regularization") { @@ -518,13 +488,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.09799775 */ val interceptR1 = 0.0 - val weightsR1 = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775) + val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) relTol 1E-2) - assert(model1.weights(1) ~== weightsR1(1) relTol 1E-2) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -546,13 +513,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.053314311 */ val interceptR2 = 0.0 - val weightsR2 = Array(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) relTol 1E-2) - assert(model2.weights(1) ~== weightsR2(1) relTol 1E-2) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR2(3) relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-2) } test("binary logistic regression with intercept with ElasticNet regularization") { @@ -583,13 +547,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.15458796 */ val interceptR1 = 0.57734851 - val weightsR1 = Array(-0.05310287, 0.0, -0.08849250, -0.15458796) + val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) assert(model1.intercept ~== interceptR1 relTol 6E-3) - assert(model1.weights(0) ~== weightsR1(0) relTol 5E-3) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 5E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3) + assert(model1.weights ~== weightsR1 absTol 5E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -611,13 +572,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.05350074 */ val interceptR2 = 0.51555993 - val weightsR2 = Array(0.0, 0.0, -0.18807395, -0.05350074) + val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) assert(model2.intercept ~== interceptR2 relTol 6E-3) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 5E-3) - assert(model2.weights(3) ~== weightsR2(3) relTol 1E-2) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression without intercept with ElasticNet regularization") { @@ -649,13 +607,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.142534158 */ val interceptR1 = 0.0 - val weightsR1 = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158) + val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) absTol 1E-2) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-2) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-2) + assert(model1.weights ~= weightsR1 absTol 1E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -677,13 +632,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.0 - val weightsR2 = Array(0.0, 0.03345223, -0.11304532, 0.0) + val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) relTol 1E-2) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2) - assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression with intercept with strong L1 regularization") { @@ -717,19 +669,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { }}} */ val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) - val weightsTheory = Array(0.0, 0.0, 0.0, 0.0) + val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) - assert(model1.weights(0) ~== weightsTheory(0) absTol 1E-6) - assert(model1.weights(1) ~== weightsTheory(1) absTol 1E-6) - assert(model1.weights(2) ~== weightsTheory(2) absTol 1E-6) - assert(model1.weights(3) ~== weightsTheory(3) absTol 1E-6) + assert(model1.weights ~= weightsTheory absTol 1E-6) assert(model2.intercept ~== interceptTheory relTol 1E-5) - assert(model2.weights(0) ~== weightsTheory(0) absTol 1E-6) - assert(model2.weights(1) ~== weightsTheory(1) absTol 1E-6) - assert(model2.weights(2) ~== weightsTheory(2) absTol 1E-6) - assert(model2.weights(3) ~== weightsTheory(3) absTol 1E-6) + assert(model2.weights ~= weightsTheory absTol 1E-6) /* Using the following R code to load the data and train the model using glmnet package. @@ -750,12 +696,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR = -0.248065 - val weightsR = Array(0.0, 0.0, 0.0, 0.0) + val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) - assert(model1.weights(0) ~== weightsR(0) absTol 1E-6) - assert(model1.weights(1) ~== weightsR(1) absTol 1E-6) - assert(model1.weights(2) ~== weightsR(2) absTol 1E-6) - assert(model1.weights(3) ~== weightsR(3) absTol 1E-6) + assert(model1.weights ~= weightsR absTol 1E-6) } } From 1903641e68ce7e7e657584bf45e91db6df357e41 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Thu, 9 Jul 2015 19:31:31 -0700 Subject: [PATCH 239/274] [SPARK-8839] [SQL] ThriftServer2 will remove session and execution no matter it's finished or not. In my test, `sessions` and `executions` in ThriftServer2 is not the same number as the connection number. For example, if there are 200 clients connecting to the server, but it will have more than 200 `sessions` and `executions`. So if it reaches the `retainedStatements`, it has to remove some object which is not finished. So it may cause the exception described in [Jira Address](https://issues.apache.org/jira/browse/SPARK-8839) Author: huangzhaowei Closes #7239 from SaintBacchus/SPARK-8839 and squashes the following commits: cf7ef40 [huangzhaowei] Remove the a meanless funciton call 3e9a5a6 [huangzhaowei] Add a filter before take 9d5ceb8 [huangzhaowei] [SPARK-8839][SQL]ThriftServer2 will remove session and execution no matter it's finished or not. --- .../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 700d994bb6a83..b7db80d93f852 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -179,6 +179,7 @@ object HiveThriftServer2 extends Logging { def onSessionClosed(sessionId: String): Unit = { sessionList(sessionId).finishTimestamp = System.currentTimeMillis onlineSessionNum -= 1 + trimSessionIfNecessary() } def onStatementStart( @@ -206,18 +207,20 @@ object HiveThriftServer2 extends Logging { executionList(id).detail = errorMessage executionList(id).state = ExecutionState.FAILED totalRunning -= 1 + trimExecutionIfNecessary() } def onStatementFinish(id: String): Unit = { executionList(id).finishTimestamp = System.currentTimeMillis executionList(id).state = ExecutionState.FINISHED totalRunning -= 1 + trimExecutionIfNecessary() } private def trimExecutionIfNecessary() = synchronized { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) - executionList.take(toRemove).foreach { s => + executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => executionList.remove(s._1) } } @@ -226,7 +229,7 @@ object HiveThriftServer2 extends Logging { private def trimSessionIfNecessary() = synchronized { if (sessionList.size > retainedSessions) { val toRemove = math.max(retainedSessions / 10, 1) - sessionList.take(toRemove).foreach { s => + sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => sessionList.remove(s._1) } } From d538919cc4fd3ab940d478c62dce1bae0270cfeb Mon Sep 17 00:00:00 2001 From: Michael Vogiatzis Date: Thu, 9 Jul 2015 19:53:23 -0700 Subject: [PATCH 240/274] [DOCS] Added important updateStateByKey details Runs for *all* existing keys and returning "None" will remove the key-value pair. Author: Michael Vogiatzis Closes #7229 from mvogiatzis/patch-1 and squashes the following commits: e7a2946 [Michael Vogiatzis] Updated updateStateByKey text 00283ed [Michael Vogiatzis] Removed space c2656f9 [Michael Vogiatzis] Moved description farther up 0a42551 [Michael Vogiatzis] Added important updateStateByKey details --- docs/streaming-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e72d5580dae55..2f3013b533eb0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -854,6 +854,8 @@ it with new information. To use this, you will have to do two steps. 1. Define the state update function - Specify with a function how to update the state using the previous state and the new values from an input stream. +In every batch, Spark will apply the state update function for all existing keys, regardless of whether they have new data in a batch or not. If the update function returns `None` then the key-value pair will be eliminated. + Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. We define the update function as: From e14b545d2dcbc4587688b4c46718d3680b0a2f67 Mon Sep 17 00:00:00 2001 From: Jonathan Alter Date: Fri, 10 Jul 2015 11:34:01 +0100 Subject: [PATCH 241/274] [SPARK-7977] [BUILD] Disallowing println Author: Jonathan Alter Closes #7093 from jonalter/SPARK-7977 and squashes the following commits: ccd44cc [Jonathan Alter] Changed println to log in ThreadingSuite 7fcac3e [Jonathan Alter] Reverting to println in ThreadingSuite 10724b6 [Jonathan Alter] Changing some printlns to logs in tests eeec1e7 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 0b1dcb4 [Jonathan Alter] More println cleanup aedaf80 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 925fd98 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 0c16fa3 [Jonathan Alter] Replacing some printlns with logs 45c7e05 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 5c8e283 [Jonathan Alter] Allowing println in audit-release examples 5b50da1 [Jonathan Alter] Allowing printlns in example files ca4b477 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 83ab635 [Jonathan Alter] Fixing new printlns 54b131f [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 1cd8a81 [Jonathan Alter] Removing some unnecessary comments and printlns b837c3a [Jonathan Alter] Disallowing println --- .../main/scala/org/apache/spark/Logging.scala | 2 ++ .../org/apache/spark/api/r/RBackend.scala | 2 ++ .../scala/org/apache/spark/api/r/RRDD.scala | 2 ++ .../org/apache/spark/deploy/Client.scala | 30 ++++++++-------- .../apache/spark/deploy/ClientArguments.scala | 4 +++ .../org/apache/spark/deploy/RRunner.scala | 2 ++ .../org/apache/spark/deploy/SparkSubmit.scala | 18 ++++++++++ .../spark/deploy/SparkSubmitArguments.scala | 4 +++ .../spark/deploy/client/TestExecutor.scala | 2 ++ .../history/HistoryServerArguments.scala | 2 ++ .../spark/deploy/master/MasterArguments.scala | 2 ++ .../MesosClusterDispatcherArguments.scala | 6 ++++ .../spark/deploy/worker/DriverWrapper.scala | 2 ++ .../spark/deploy/worker/WorkerArguments.scala | 4 +++ .../CoarseGrainedExecutorBackend.scala | 4 +++ .../input/FixedLengthBinaryInputFormat.scala | 7 ++-- .../spark/network/nio/BlockMessage.scala | 22 ------------ .../spark/network/nio/BlockMessageArray.scala | 34 ++++--------------- .../spark/network/nio/ConnectionManager.scala | 4 +++ .../scala/org/apache/spark/rdd/PipedRDD.scala | 4 +++ .../scheduler/EventLoggingListener.scala | 2 ++ .../apache/spark/scheduler/JobLogger.scala | 2 ++ .../org/apache/spark/ui/JettyUtils.scala | 2 ++ .../apache/spark/ui/UIWorkloadGenerator.scala | 6 +++- .../org/apache/spark/util/Distribution.scala | 6 ++++ .../spark/util/random/XORShiftRandom.scala | 2 ++ .../org/apache/spark/DistributedSuite.scala | 2 ++ .../scala/org/apache/spark/FailureSuite.scala | 2 ++ .../org/apache/spark/FileServerSuite.scala | 2 ++ .../org/apache/spark/ThreadingSuite.scala | 6 ++-- .../spark/deploy/SparkSubmitSuite.scala | 4 +++ .../spark/deploy/SparkSubmitUtilsSuite.scala | 2 ++ .../WholeTextFileRecordReaderSuite.scala | 8 ++--- .../metrics/InputOutputMetricsSuite.scala | 2 ++ .../spark/scheduler/ReplayListenerSuite.scala | 2 ++ .../spark/util/ClosureCleanerSuite.scala | 2 ++ .../org/apache/spark/util/UtilsSuite.scala | 2 ++ .../util/collection/SizeTrackerSuite.scala | 4 +++ .../spark/util/collection/SorterSuite.scala | 10 +++--- .../src/main/scala/SparkApp.scala | 2 ++ .../src/main/scala/SparkApp.scala | 2 ++ .../src/main/scala/GraphxApp.scala | 2 ++ .../sbt_app_hive/src/main/scala/HiveApp.scala | 2 ++ .../src/main/scala/SparkApp.scala | 2 ++ .../sbt_app_sql/src/main/scala/SqlApp.scala | 2 ++ .../src/main/scala/StreamingApp.scala | 2 ++ .../apache/spark/examples/BroadcastTest.scala | 2 ++ .../spark/examples/CassandraCQLTest.scala | 2 ++ .../apache/spark/examples/CassandraTest.scala | 2 ++ .../spark/examples/DFSReadWriteTest.scala | 2 ++ .../spark/examples/DriverSubmissionTest.scala | 2 ++ .../apache/spark/examples/GroupByTest.scala | 2 ++ .../org/apache/spark/examples/HBaseTest.scala | 2 ++ .../org/apache/spark/examples/HdfsTest.scala | 2 ++ .../org/apache/spark/examples/LocalALS.scala | 2 ++ .../apache/spark/examples/LocalFileLR.scala | 2 ++ .../apache/spark/examples/LocalKMeans.scala | 2 ++ .../org/apache/spark/examples/LocalLR.scala | 2 ++ .../org/apache/spark/examples/LocalPi.scala | 2 ++ .../org/apache/spark/examples/LogQuery.scala | 2 ++ .../spark/examples/MultiBroadcastTest.scala | 2 ++ .../examples/SimpleSkewedGroupByTest.scala | 2 ++ .../spark/examples/SkewedGroupByTest.scala | 2 ++ .../org/apache/spark/examples/SparkALS.scala | 2 ++ .../apache/spark/examples/SparkHdfsLR.scala | 2 ++ .../apache/spark/examples/SparkKMeans.scala | 2 ++ .../org/apache/spark/examples/SparkLR.scala | 2 ++ .../apache/spark/examples/SparkPageRank.scala | 2 ++ .../org/apache/spark/examples/SparkPi.scala | 2 ++ .../org/apache/spark/examples/SparkTC.scala | 2 ++ .../spark/examples/SparkTachyonHdfsLR.scala | 2 ++ .../spark/examples/SparkTachyonPi.scala | 2 ++ .../spark/examples/graphx/Analytics.scala | 2 ++ .../examples/graphx/LiveJournalPageRank.scala | 2 ++ .../examples/graphx/SynthBenchmark.scala | 2 ++ .../examples/ml/CrossValidatorExample.scala | 2 ++ .../examples/ml/DecisionTreeExample.scala | 2 ++ .../examples/ml/DeveloperApiExample.scala | 2 ++ .../apache/spark/examples/ml/GBTExample.scala | 2 ++ .../examples/ml/LinearRegressionExample.scala | 2 ++ .../ml/LogisticRegressionExample.scala | 2 ++ .../spark/examples/ml/MovieLensALS.scala | 2 ++ .../spark/examples/ml/OneVsRestExample.scala | 2 ++ .../examples/ml/RandomForestExample.scala | 2 ++ .../examples/ml/SimpleParamsExample.scala | 2 ++ .../ml/SimpleTextClassificationPipeline.scala | 2 ++ .../examples/mllib/BinaryClassification.scala | 2 ++ .../spark/examples/mllib/Correlations.scala | 2 ++ .../examples/mllib/CosineSimilarity.scala | 2 ++ .../spark/examples/mllib/DatasetExample.scala | 2 ++ .../examples/mllib/DecisionTreeRunner.scala | 2 ++ .../examples/mllib/DenseGaussianMixture.scala | 2 ++ .../spark/examples/mllib/DenseKMeans.scala | 2 ++ .../examples/mllib/FPGrowthExample.scala | 2 ++ .../mllib/GradientBoostedTreesRunner.scala | 2 ++ .../spark/examples/mllib/LDAExample.scala | 2 ++ .../examples/mllib/LinearRegression.scala | 2 ++ .../spark/examples/mllib/MovieLensALS.scala | 2 ++ .../mllib/MultivariateSummarizer.scala | 2 ++ .../PowerIterationClusteringExample.scala | 3 +- .../examples/mllib/RandomRDDGeneration.scala | 2 ++ .../spark/examples/mllib/SampledRDDs.scala | 2 ++ .../examples/mllib/SparseNaiveBayes.scala | 2 ++ .../mllib/StreamingKMeansExample.scala | 2 ++ .../mllib/StreamingLinearRegression.scala | 2 ++ .../mllib/StreamingLogisticRegression.scala | 2 ++ .../spark/examples/mllib/TallSkinnyPCA.scala | 2 ++ .../spark/examples/mllib/TallSkinnySVD.scala | 2 ++ .../spark/examples/sql/RDDRelation.scala | 2 ++ .../examples/sql/hive/HiveFromSpark.scala | 2 ++ .../examples/streaming/ActorWordCount.scala | 2 ++ .../examples/streaming/CustomReceiver.scala | 2 ++ .../streaming/DirectKafkaWordCount.scala | 2 ++ .../examples/streaming/FlumeEventCount.scala | 2 ++ .../streaming/FlumePollingEventCount.scala | 2 ++ .../examples/streaming/HdfsWordCount.scala | 2 ++ .../examples/streaming/KafkaWordCount.scala | 2 ++ .../examples/streaming/MQTTWordCount.scala | 4 +++ .../examples/streaming/NetworkWordCount.scala | 2 ++ .../examples/streaming/RawNetworkGrep.scala | 2 ++ .../RecoverableNetworkWordCount.scala | 2 ++ .../streaming/SqlNetworkWordCount.scala | 2 ++ .../streaming/StatefulNetworkWordCount.scala | 2 ++ .../streaming/TwitterAlgebirdCMS.scala | 2 ++ .../streaming/TwitterAlgebirdHLL.scala | 2 ++ .../streaming/TwitterPopularTags.scala | 2 ++ .../examples/streaming/ZeroMQWordCount.scala | 2 ++ .../clickstream/PageViewGenerator.scala | 2 ++ .../clickstream/PageViewStream.scala | 2 ++ .../kafka/DirectKafkaStreamSuite.scala | 2 +- .../streaming/KinesisWordCountASL.scala | 2 ++ .../spark/graphx/util/BytecodeUtils.scala | 1 - .../spark/graphx/util/GraphGenerators.scala | 4 +-- .../graphx/util/BytecodeUtilsSuite.scala | 2 ++ .../mllib/util/KMeansDataGenerator.scala | 2 ++ .../mllib/util/LinearDataGenerator.scala | 2 ++ .../LogisticRegressionDataGenerator.scala | 2 ++ .../spark/mllib/util/MFDataGenerator.scala | 2 ++ .../spark/mllib/util/SVMDataGenerator.scala | 2 ++ .../spark/ml/feature/VectorIndexerSuite.scala | 10 +++--- .../spark/mllib/linalg/VectorsSuite.scala | 6 ++-- .../spark/mllib/stat/CorrelationSuite.scala | 6 ++-- .../tree/GradientBoostedTreesSuite.scala | 10 +++--- .../spark/mllib/util/NumericParserSuite.scala | 2 +- project/SparkBuild.scala | 4 +++ .../apache/spark/repl/SparkCommandLine.scala | 2 ++ .../org/apache/spark/repl/SparkILoop.scala | 2 ++ .../apache/spark/repl/SparkILoopInit.scala | 2 ++ .../org/apache/spark/repl/SparkIMain.scala | 2 ++ .../org/apache/spark/repl/SparkILoop.scala | 2 ++ .../org/apache/spark/repl/SparkIMain.scala | 4 +++ .../apache/spark/repl/SparkReplReporter.scala | 2 ++ scalastyle-config.xml | 12 +++---- .../expressions/codegen/package.scala | 2 ++ .../spark/sql/catalyst/plans/QueryPlan.scala | 2 ++ .../spark/sql/catalyst/util/package.scala | 2 ++ .../apache/spark/sql/types/StructType.scala | 2 ++ .../scala/org/apache/spark/sql/Column.scala | 2 ++ .../org/apache/spark/sql/DataFrame.scala | 6 ++++ .../spark/sql/execution/debug/package.scala | 16 ++++----- .../hive/thriftserver/SparkSQLCLIDriver.scala | 12 ++++--- .../apache/spark/sql/hive/HiveContext.scala | 5 +-- .../org/apache/spark/sql/hive/HiveQl.scala | 5 +-- .../spark/sql/hive/client/ClientWrapper.scala | 2 ++ .../regression-test-SPARK-8489/Main.scala | 2 ++ .../sql/hive/HiveMetastoreCatalogSuite.scala | 6 ++-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 ++ .../sql/hive/InsertIntoHiveTableSuite.scala | 2 -- .../sql/hive/MetastoreDataSourcesSuite.scala | 6 ++-- .../sql/hive/execution/HiveUDFSuite.scala | 1 - .../spark/streaming/dstream/DStream.scala | 2 ++ .../spark/streaming/util/RawTextSender.scala | 2 ++ .../spark/streaming/util/RecurringTimer.scala | 4 +-- .../spark/streaming/MasterFailureTest.scala | 4 +++ .../scheduler/JobGeneratorSuite.scala | 1 - .../spark/tools/GenerateMIMAIgnore.scala | 8 +++++ .../tools/JavaAPICompletenessChecker.scala | 4 +++ .../spark/tools/StoragePerfTester.scala | 4 +++ .../yarn/ApplicationMasterArguments.scala | 4 +++ .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/ClientArguments.scala | 4 +++ .../spark/deploy/yarn/YarnClusterSuite.scala | 4 +++ 182 files changed, 478 insertions(+), 135 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 7fcb7830e7b0b..87ab099267b2f 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,6 +121,7 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { + // scalastyle:off println if (Utils.isInInterpreter) { val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { @@ -141,6 +142,7 @@ trait Logging { System.err.println(s"Spark was unable to load $defaultLogProps") } } + // scalastyle:on println } } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 1a5f2bca26c2b..b7e72d4d0ed0b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -95,7 +95,9 @@ private[spark] class RBackend { private[spark] object RBackend extends Logging { def main(args: Array[String]): Unit = { if (args.length < 1) { + // scalastyle:off println System.err.println("Usage: RBackend ") + // scalastyle:on println System.exit(-1) } val sparkRBackend = new RBackend() diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 524676544d6f5..ff1702f7dea48 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -161,7 +161,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { // write string(for StringRRDD) + // scalastyle:off println printOut.println(elem) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 71f7e2129116f..f03875a3e8c89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -118,26 +118,26 @@ private class ClientEndpoint( def pollAndReportStatus(driverId: String) { // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread // is fine. - println("... waiting before polling master for driver state") + logInfo("... waiting before polling master for driver state") Thread.sleep(5000) - println("... polling master for driver state") + logInfo("... polling master for driver state") val statusResponse = activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => - println(s"ERROR: Cluster master did not recognize $driverId") + logError(s"ERROR: Cluster master did not recognize $driverId") System.exit(-1) case true => - println(s"State of $driverId is ${statusResponse.state.get}") + logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - println(s"Driver running on $hostPort ($id)") + logInfo(s"Driver running on $hostPort ($id)") case _ => } // Exception, if present statusResponse.exception.map { e => - println(s"Exception from cluster was: $e") + logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) } @@ -148,7 +148,7 @@ private class ClientEndpoint( override def receive: PartialFunction[Any, Unit] = { case SubmitDriverResponse(master, success, driverId, message) => - println(message) + logInfo(message) if (success) { activeMasterEndpoint = master pollAndReportStatus(driverId.get) @@ -158,7 +158,7 @@ private class ClientEndpoint( case KillDriverResponse(master, driverId, success, message) => - println(message) + logInfo(message) if (success) { activeMasterEndpoint = master pollAndReportStatus(driverId) @@ -169,13 +169,13 @@ private class ClientEndpoint( override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") + logError(s"Error connecting to master $remoteAddress.") lostMasters += remoteAddress // Note that this heuristic does not account for the fact that a Master can recover within // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This // is not currently a concern, however, because this client does not retry submissions. if (lostMasters.size >= masterEndpoints.size) { - println("No master is available, exiting.") + logError("No master is available, exiting.") System.exit(-1) } } @@ -183,18 +183,18 @@ private class ClientEndpoint( override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") + logError(s"Error connecting to master ($remoteAddress).") + logError(s"Cause was: $cause") lostMasters += remoteAddress if (lostMasters.size >= masterEndpoints.size) { - println("No master is available, exiting.") + logError("No master is available, exiting.") System.exit(-1) } } } override def onError(cause: Throwable): Unit = { - println(s"Error processing messages, exiting.") + logError(s"Error processing messages, exiting.") cause.printStackTrace() System.exit(-1) } @@ -209,10 +209,12 @@ private class ClientEndpoint( */ object Client { def main(args: Array[String]) { + // scalastyle:off println if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a future version of Spark") println("Use ./bin/spark-submit with \"--master spark://host:port\"") } + // scalastyle:on println val conf = new SparkConf() val driverArgs = new ClientArguments(args) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 42d3296062e6d..72cc330a398da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) { cmd = "launch" if (!ClientArguments.isValidJarUrl(_jarUrl)) { + // scalastyle:off println println(s"Jar url '${_jarUrl}' is not in valid format.") println(s"Must be a jar file path in URL format " + "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)") + // scalastyle:on println printUsageAndExit(-1) } @@ -110,7 +112,9 @@ private[deploy] class ClientArguments(args: Array[String]) { | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin + // scalastyle:off println System.err.println(usage) + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index e99779f299785..4165740312e03 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -85,7 +85,9 @@ object RRunner { } System.exit(returnCode) } else { + // scalastyle:off println System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b1d6ec209d62b..4cec9017b8adb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -82,6 +82,7 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // scalastyle:off println // Exposed for testing private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err @@ -102,11 +103,14 @@ object SparkSubmit { printStream.println("Type --help for more information.") exitFn(0) } + // scalastyle:on println def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { + // scalastyle:off println printStream.println(appArgs) + // scalastyle:on println } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs) @@ -160,7 +164,9 @@ object SparkSubmit { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { + // scalastyle:off println printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println exitFn(1) } else { throw e @@ -178,7 +184,9 @@ object SparkSubmit { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { + // scalastyle:off println printStream.println("Running Spark using the REST application submission protocol.") + // scalastyle:on println doRunMain() } catch { // Fail over to use the legacy submission gateway @@ -558,6 +566,7 @@ object SparkSubmit { sysProps: Map[String, String], childMainClass: String, verbose: Boolean): Unit = { + // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -565,6 +574,7 @@ object SparkSubmit { printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } + // scalastyle:on println val loader = if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { @@ -592,8 +602,10 @@ object SparkSubmit { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { + // scalastyle:off println printStream.println(s"Failed to load main class $childMainClass.") printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -766,7 +778,9 @@ private[spark] object SparkSubmitUtils { brr.setRoot(repo) brr.setName(s"repo-${i + 1}") cr.add(brr) + // scalastyle:off println printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println } } @@ -829,7 +843,9 @@ private[spark] object SparkSubmitUtils { val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) dd.addDependencyConfiguration(ivyConfName, ivyConfName) + // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println md.addDependency(dd) } } @@ -896,9 +912,11 @@ private[spark] object SparkSubmitUtils { ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } + // scalastyle:off println printStream.println( s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 6e3c0b21b33c2..ebb39c354dff1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -79,6 +79,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() + // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => @@ -86,6 +87,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } + // scalastyle:on println defaultProperties } @@ -452,6 +454,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { + // scalastyle:off println val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) @@ -541,6 +544,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("CLI options:") outStream.println(getSqlShellOptions()) } + // scalastyle:on println SparkSubmit.exitFn(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala index c5ac45c6730d3..a98b1fa8f83a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.client private[spark] object TestExecutor { def main(args: Array[String]) { + // scalastyle:off println println("Hello world!") + // scalastyle:on println while (true) { Thread.sleep(1000) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 4692d22651c93..18265df9faa2c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin Utils.loadDefaultSparkProperties(conf, propertiesFile) private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( """ |Usage: HistoryServer [options] @@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin | spark.history.fs.updateInterval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) + // scalastyle:on println System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 435b9b12f83b8..44cefbc77f08e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. */ private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Master [options]\n" + "\n" + @@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8080)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 894cb78d8591a..5accaf78d0a51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -54,7 +54,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { + // scalastyle:off println System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + // scalastyle:on println System.exit(1) } masterUrl = value.stripPrefix("mesos://") @@ -73,7 +75,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case Nil => { if (masterUrl == null) { + // scalastyle:off println System.err.println("--master is required") + // scalastyle:on println printUsageAndExit(1) } } @@ -83,6 +87,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: } private def printUsageAndExit(exitCode: Int): Unit = { + // scalastyle:off println System.err.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + @@ -96,6 +101,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index d1a12b01e78f7..2d6be3042c905 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -60,7 +60,9 @@ object DriverWrapper { rpcEnv.shutdown() case _ => + // scalastyle:off println System.err.println("Usage: DriverWrapper [options]") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 1d2ecab517613..e89d076802215 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. */ def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Worker [options] \n" + "\n" + @@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8081)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } @@ -160,7 +162,9 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } catch { case e: Exception => { totalMb = 2*1024 + // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + // scalastyle:on println } } // Leave out 1 GB for the operating system, but don't return a negative memory size diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 34d4cfdca7732..fcd76ec52742a 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -235,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { argv = tail case Nil => case tail => + // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + // scalastyle:on println printUsageAndExit() } } @@ -249,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } private def printUsageAndExit() = { + // scalastyle:off println System.err.println( """ |"Usage: CoarseGrainedExecutorBackend [options] @@ -262,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { | --worker-url | --user-class-path |""".stripMargin) + // scalastyle:on println System.exit(1) } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index c219d21fbefa9..532850dd57716 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil /** @@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat { } private[spark] class FixedLengthBinaryInputFormat - extends FileInputFormat[LongWritable, BytesWritable] { + extends FileInputFormat[LongWritable, BytesWritable] + with Logging { private var recordLength = -1 @@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) } if (recordLength <= 0) { - println("record length is less than 0, file cannot be split") + logDebug("record length is less than 0, file cannot be split") false } else { true diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index 67a376102994c..79cb0640c8672 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -57,16 +57,6 @@ private[nio] class BlockMessage() { } def set(buffer: ByteBuffer) { - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ typ = buffer.getInt() val idLength = buffer.getInt() val idBuilder = new StringBuilder(idLength) @@ -138,18 +128,6 @@ private[nio] class BlockMessage() { buffers += data } - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 7d0806f0c2580..f1c9ea8b64ca3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -43,16 +43,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) val newBlockMessages = new ArrayBuffer[BlockMessage]() val buffer = bufferMessage.buffers(0) buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ while (buffer.remaining() > 0) { val size = buffer.getInt() logDebug("Creating block message of size " + size + " bytes") @@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } } -private[nio] object BlockMessageArray { +private[nio] object BlockMessageArray extends Logging { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() @@ -123,10 +101,10 @@ private[nio] object BlockMessageArray { } } val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") + logDebug("Block message array created") val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") + logDebug("Converted to buffer message") val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) @@ -138,10 +116,11 @@ private[nio] object BlockMessageArray { }) newBuffer.flip val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) + logDebug("Copied to new buffer message, size = " + newBufferMessage.size) val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") + logDebug("Converted back to block message array") + // scalastyle:off println newBlockMessageArray.foreach(blockMessage => { blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { @@ -154,6 +133,7 @@ private[nio] object BlockMessageArray { } } }) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index c0bca2c4bc994..9143918790381 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -1016,7 +1016,9 @@ private[spark] object ConnectionManager { val conf = new SparkConf val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // scalastyle:off println println("Received [" + msg + "] from [" + id + "]") + // scalastyle:on println None }) @@ -1033,6 +1035,7 @@ private[spark] object ConnectionManager { System.gc() } + // scalastyle:off println def testSequentialSending(manager: ConnectionManager) { println("--------------------------") println("Sequential Sending") @@ -1150,4 +1153,5 @@ private[spark] object ConnectionManager { println() } } + // scalastyle:on println } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index dc60d48927624..defdabf95ac4b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread("stderr reader for " + command) { override def run() { for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + // scalastyle:off println System.err.println(line) + // scalastyle:on println } } }.start() @@ -133,6 +135,7 @@ private[spark] class PipedRDD[T: ClassTag]( override def run() { val out = new PrintWriter(proc.getOutputStream) + // scalastyle:off println // input the pipe context firstly if (printPipeContext != null) { printPipeContext(out.println(_)) @@ -144,6 +147,7 @@ private[spark] class PipedRDD[T: ClassTag]( out.println(elem) } } + // scalastyle:on println out.close() } }.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 529a5b2bf1a0d..62b05033a9281 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -140,7 +140,9 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. */ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) + // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) + // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index e55b76c36cc5f..f96eb8ca0ae00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener val date = new Date(System.currentTimeMillis()) writeInfo = dateFormat.get.format(date) + ": " + info } + // scalastyle:off println jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) + // scalastyle:on println } /** diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index f413c1d37fbb6..c8356467fab87 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -68,7 +68,9 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) + // scalastyle:on println } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index ba03acdb38cc5..5a8c2914314c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -38,9 +38,11 @@ private[spark] object UIWorkloadGenerator { def main(args: Array[String]) { if (args.length < 3) { + // scalastyle:off println println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") + // scalastyle:on println System.exit(1) } @@ -96,6 +98,7 @@ private[spark] object UIWorkloadGenerator { for ((desc, job) <- jobs) { new Thread { override def run() { + // scalastyle:off println try { setProperties(desc) job() @@ -106,6 +109,7 @@ private[spark] object UIWorkloadGenerator { } finally { barrier.release() } + // scalastyle:on println } }.start Thread.sleep(INTER_JOB_WAIT_MS) diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 1bab707235b89..950b69f7db641 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -52,9 +52,11 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va } def showQuantiles(out: PrintStream = System.out): Unit = { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) @@ -64,8 +66,10 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va * @param out */ def summary(out: PrintStream = System.out) { + // scalastyle:off println out.println(statCounter) showQuantiles(out) + // scalastyle:on println } } @@ -80,8 +84,10 @@ private[spark] object Distribution { } def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") quantiles.foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index c4a7b4441c85c..85fb923cd9bc7 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -70,12 +70,14 @@ private[spark] object XORShiftRandom { * @param args takes one argument - the number of random numbers to generate */ def main(args: Array[String]): Unit = { + // scalastyle:off println if (args.length != 1) { println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") println("Usage: XORShiftRandom number_of_random_numbers_to_generate") System.exit(1) } println(benchmark(args(0).toInt)) + // scalastyle:on println } /** diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 9c191ed52206d..2300bcff4f118 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -107,7 +107,9 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext(clusterUrl, "test") val accum = sc.accumulator(0) val thrown = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) + // scalastyle:on println } assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("failed 4 times")) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index a8c8c6f73fb5a..b099cd3fb7965 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -130,7 +130,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // Non-serializable closure in foreach function val thrown2 = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 2).foreach(x => println(a)) + // scalastyle:on println } assert(thrown2.getClass === classOf[SparkException]) assert(thrown2.getMessage.contains("NotSerializableException") || diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 6e65b0a8f6c76..876418aa13029 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -51,7 +51,9 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { val textFile = new File(testTempDir, "FileServerSuite.txt") val pw = new PrintWriter(textFile) + // scalastyle:off println pw.println("100") + // scalastyle:on println pw.close() val jarFile = new File(testTempDir, "test.jar") diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 6580139df6c60..48509f0759a3b 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -36,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends SparkFunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") @@ -130,8 +130,6 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(100) } if (running.get() != 4) { - println("Waited 1 second without seeing runningThreads = 4 (it was " + - running.get() + "); failing test") ThreadingSuiteState.failed.set(true) } number @@ -143,6 +141,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { } sem.acquire(2) if (ThreadingSuiteState.failed.get()) { + logError("Waited 1 second without seeing runningThreads = 4 (it was " + + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2e05dec99b6bf..1b64c329b5d4b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -51,9 +51,11 @@ class SparkSubmitSuite /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } /** Returns true if the script exits and the given search string is printed. */ @@ -81,6 +83,7 @@ class SparkSubmitSuite } } + // scalastyle:off println test("prints usage on empty input") { testPrematureExit(Array[String](), "Usage: spark-submit") } @@ -491,6 +494,7 @@ class SparkSubmitSuite appArgs.executorMemory should be ("2.3g") } } + // scalastyle:on println // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index c9b435a9228d3..01ece1a10f46d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -41,9 +41,11 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } override def beforeAll() { diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 63947df3d43a2..8a199459c1ddf 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -36,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. */ -class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ @@ -85,7 +85,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl */ test("Correctness of WholeTextFileRecordReader.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, false) @@ -109,7 +109,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl test("Correctness of WholeTextFileRecordReader with GzipCodec.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, true) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 9e4d34fb7d382..d3218a548efc7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -60,7 +60,9 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { + // scalastyle:off println pw.println(RandomUtils.nextInt(0, numBuckets)) + // scalastyle:on println } pw.close() diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index ff3fa95ec32ae..4e3defb43a021 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -52,8 +52,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) + // scalastyle:off println writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) + // scalastyle:on println writer.close() val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 1053c6caf7718..480722a5ac182 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -375,6 +375,7 @@ class TestCreateNullValue { // parameters of the closure constructor. This allows us to test whether // null values are created correctly for each type. val nestedClosure = () => { + // scalastyle:off println if (s.toString == "123") { // Don't really output them to avoid noisy println(bo) println(c) @@ -389,6 +390,7 @@ class TestCreateNullValue { val closure = () => { println(getX) } + // scalastyle:on println ClosureCleaner.clean(closure) } nestedClosure() diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 251a797dc28a2..c7638507c88c6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -684,7 +684,9 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val buffer = new CircularBuffer(25) val stream = new java.io.PrintStream(buffer, true, "UTF-8") + // scalastyle:off println stream.println("test circular test circular test circular test circular test circular") + // scalastyle:on println assert(buffer.toString === "t circular test circular\n") } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 5a5919fca2469..4f382414a8dd7 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -103,7 +103,9 @@ private object SizeTrackerSuite { */ def main(args: Array[String]): Unit = { if (args.size < 1) { + // scalastyle:off println println("Usage: SizeTrackerSuite [num elements]") + // scalastyle:on println System.exit(1) } val numElements = args(0).toInt @@ -180,11 +182,13 @@ private object SizeTrackerSuite { baseTimes: Seq[Long], sampledTimes: Seq[Long], unsampledTimes: Seq[Long]): Unit = { + // scalastyle:off println println(s"Average times for $testName (ms):") println(" Base - " + averageTime(baseTimes)) println(" SizeTracker (sampled) - " + averageTime(sampledTimes)) println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes)) println() + // scalastyle:on println } def time(f: => Unit): Long = { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index b2f5d9009ee5d..fefa5165db197 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends SparkFunSuite { +class SorterSuite extends SparkFunSuite with Logging { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) @@ -74,7 +74,7 @@ class SorterSuite extends SparkFunSuite { /** Runs an experiment several times. */ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { if (skip) { - println(s"Skipped experiment $name.") + logInfo(s"Skipped experiment $name.") return } @@ -86,11 +86,11 @@ class SorterSuite extends SparkFunSuite { while (i < 10) { val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) next10 += time - println(s"$name: Took $time ms") + logInfo(s"$name: Took $time ms") i += 1 } - println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + logInfo(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") } /** diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index fc03fec9866a6..61d91c70e9709 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -59,3 +60,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala index 0be8e64fbfabd..9f7ae75d0b477 100644 --- a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -37,3 +38,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala index 24c7f8d667296..2f0b6ef9a5672 100644 --- a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala +++ b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import org.apache.spark.{SparkContext, SparkConf} @@ -51,3 +52,4 @@ object GraphXApp { println("Test succeeded") } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index 5111bc0adb772..4a980ec071ae4 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -55,3 +56,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala index 9f85066501472..adc25b57d6aa5 100644 --- a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -31,3 +32,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index cc86ef45858c9..69c1154dc0955 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -57,3 +58,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala index 58a662bd9b2e8..d6a074687f4a1 100644 --- a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala +++ b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -61,3 +62,4 @@ object SparkStreamingExample { ssc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 4c129dbe2d12d..d812262fd87dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -52,3 +53,4 @@ object BroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 023bb3ee2d108..36832f51d2ad4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ + // scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -140,3 +141,4 @@ object CassandraCQLTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ec689474aecb0..96ef3e198e380 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { sc.stop() } } +// scalastyle:on println /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 1f12034ce0f57..d651fe4d6ee75 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.io.File @@ -136,3 +137,4 @@ object DFSReadWriteTest { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index e757283823fc3..c42df2b8845d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.collection.JavaConversions._ @@ -46,3 +47,4 @@ object DriverSubmissionTest { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 15f6678648b29..fa4a3afeecd19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -53,3 +54,4 @@ object GroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 95c96111c9b1f..244742327a907 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin @@ -62,3 +63,4 @@ object HBaseTest { admin.close() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index ed2b38e2ca6f8..124dc9af6390f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark._ @@ -41,3 +42,4 @@ object HdfsTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3d5259463003d..af5f216f28ba4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -142,3 +143,4 @@ object LocalALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index ac2ea35bbd0e0..9c8aae53cf48d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -73,3 +74,4 @@ object LocalFileLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 04fc0a033014a..e7b28d38bdfc6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -119,3 +120,4 @@ object LocalKMeans { println("Final centers: " + kPoints) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index c3fc74a116c0a..4f6b092a59ca5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -77,3 +78,4 @@ object LocalLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index ee6b3ee34aeb2..3d923625f11b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -33,3 +34,4 @@ object LocalPi { println("Pi is roughly " + 4 * count / 100000.0) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 75c82117cbad2..a80de10f4610a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -83,3 +84,4 @@ object LogQuery { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 2a5c0c0defe13..61ce9db914f9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.rdd.RDD @@ -53,3 +54,4 @@ object MultiBroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 5291ab81f459e..3b0b00fe4dd0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -67,3 +68,4 @@ object SimpleSkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 017d4e1e5ce13..719e2176fed3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -57,3 +58,4 @@ object SkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 30c4261551837..69799b7c2bb30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -144,3 +145,4 @@ object SparkALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9099c2fcc90b3..505ea5a4c7a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -97,3 +98,4 @@ object SparkHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index b514d9123f5e7..c56e1124ad415 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import breeze.linalg.{Vector, DenseVector, squaredDistance} @@ -100,3 +101,4 @@ object SparkKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 1e6b4fb0c7514..d265c227f4ed2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -86,3 +87,4 @@ object SparkLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index bd7894f184c4c..0fd79660dd196 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.SparkContext._ @@ -74,3 +75,4 @@ object SparkPageRank { ctx.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 35b8dd6c29b66..818d4f2b81f82 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -37,3 +38,4 @@ object SparkPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 772cd897f5140..95072071ccddb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.util.Random @@ -70,3 +71,4 @@ object SparkTC { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 4393b99e636b6..cfbdae02212a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -94,3 +95,4 @@ object SparkTachyonHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 7743f7968b100..e46ac655beb58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -46,3 +47,4 @@ object SparkTachyonPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 409721b01c8fd..8dd6c9706e7df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import scala.collection.mutable @@ -151,3 +152,4 @@ object Analytics extends Logging { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index f6f8d9f90c275..da3ffca1a6f2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx /** @@ -42,3 +43,4 @@ object LiveJournalPageRank { Analytics.main(args.patch(0, List("pagerank"), 0)) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 3ec20d594b784..46e52aacd90bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ @@ -128,3 +129,4 @@ object SynthBenchmark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 6c0af20461d3b..14b358d46f6ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -110,3 +111,4 @@ object CrossValidatorExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 54e4073941056..f28671f7869fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -355,3 +356,4 @@ object DecisionTreeExample { println(s" Root mean squared error (RMSE): $RMSE") } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 7b8cc21ed8982..78f31b4ffe56a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -181,3 +182,4 @@ private class MyLogisticRegressionModel( copyValues(new MyLogisticRegressionModel(uid, weights), extra) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 33905277c7341..f4a15f806ea81 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -236,3 +237,4 @@ object GBTExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b54466fd48bc5..b73299fb12d3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -140,3 +141,4 @@ object LinearRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 3cf193f353fbc..7682557127b51 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -157,3 +158,4 @@ object LogisticRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 25f21113bf622..cd411397a4b9d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scopt.OptionParser @@ -178,3 +179,4 @@ object MovieLensALS { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 6927eb8f275cf..bab31f585b0ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} @@ -183,3 +184,4 @@ object OneVsRestExample { (NANO.toSeconds(t1 - t0), result) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 9f7cad68a4594..109178f4137b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -244,3 +245,4 @@ object RandomForestExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index a0561e2573fc9..58d7b67674ff7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -100,3 +101,4 @@ object SimpleParamsExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 1324b066c30c3..960280137cbf9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.beans.BeanInfo @@ -89,3 +90,4 @@ object SimpleTextClassificationPipeline { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a113653810b93..1a4016f76c2ad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -153,3 +154,4 @@ object BinaryClassification { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index e49129c4e7844..026d4ecc6d10a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -91,3 +92,4 @@ object Correlations { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index cb1abbd18fd4d..69988cc1b9334 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -106,3 +107,4 @@ object CosineSimilarity { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 520893b26d595..dc13f82488af7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.io.File @@ -119,3 +120,4 @@ object DatasetExample { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 3381941673db8..57ffe3dd2524f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.language.reflectiveCalls @@ -368,3 +369,4 @@ object DecisionTreeRunner { } // scalastyle:on structural.type } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index f8c71ccabc43b..1fce4ba7efd60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -65,3 +66,4 @@ object DenseGaussianMixture { println() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 14cc5cbb679c5..380d85d60e7b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -107,3 +108,4 @@ object DenseKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 13f24a1e59610..14b930550d554 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -80,3 +81,4 @@ object FPGrowthExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 7416fb5a40848..e16a6bf033574 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -145,3 +146,4 @@ object GradientBoostedTreesRunner { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 31d629f853161..75b0f69cf91aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.text.BreakIterator @@ -302,3 +303,4 @@ private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Se } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 6a456ba7ec07b..8878061a0970b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -134,3 +135,4 @@ object LinearRegression { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 99588b0984ab2..e43a6f2864c73 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.collection.mutable @@ -189,3 +190,4 @@ object MovieLensALS { math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 6e4e2d07f284b..5f839c75dd581 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -97,3 +98,4 @@ object MultivariateSummarizer { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 6d8b806569dfd..0723223954610 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -154,4 +155,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } - +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 924b586e3af99..bee85ba0f9969 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.random.RandomRDDs @@ -58,3 +59,4 @@ object RandomRDDGeneration { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 663c12734af68..6963f43e082c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.util.MLUtils @@ -125,3 +126,4 @@ object SampledRDDs { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5e..f81fc292a3bd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -100,3 +101,4 @@ object SparseNaiveBayes { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 8bb12d2ee9ed2..af03724a8ac62 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.SparkConf @@ -75,3 +76,4 @@ object StreamingKMeansExample { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1a95048bbfe2d..b4a5dca031abd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -69,3 +70,4 @@ object StreamingLinearRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index e1998099c2d78..b42f4cb5f9338 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -71,3 +72,4 @@ object StreamingLogisticRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 3cd9cb743e309..464fbd385ab5d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnyPCA { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 4d6690318615a..65b4bc46f0266 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnySVD { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index b11e32047dc34..2cc56f04e5c1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} @@ -73,3 +74,4 @@ object RDDRelation { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b7ba60ec28155..bf40bd1ef13df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql.hive import com.google.common.io.{ByteStreams, Files} @@ -77,3 +78,4 @@ object HiveFromSpark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 016de4c63d1d2..e9c9907198769 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import scala.collection.mutable.LinkedList @@ -170,3 +171,4 @@ object ActorWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 30269a7ccae97..28e9bf520e568 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.{InputStreamReader, BufferedReader, InputStream} @@ -100,3 +101,4 @@ class CustomReceiver(host: String, port: Int) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index fbe394de4a179..bd78526f8c299 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import kafka.serializer.StringDecoder @@ -70,3 +71,4 @@ object DirectKafkaWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 20e7df7c45b1b..91e52e4eff5a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -66,3 +67,4 @@ object FlumeEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 1cc8c8d5c23b6..2bdbc37e2a289 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -65,3 +66,4 @@ object FlumePollingEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 4b4667fec44e6..1f282d437dc38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -53,3 +54,4 @@ object HdfsWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 60416ee343544..b40d17e9c2fa3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.util.HashMap @@ -101,3 +102,4 @@ object KafkaWordCountProducer { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 813c8554f5193..d772ae309f40d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.eclipse.paho.client.mqttv3._ @@ -96,8 +97,10 @@ object MQTTWordCount { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println System.err.println( "Usage: MQTTWordCount ") + // scalastyle:on println System.exit(1) } @@ -113,3 +116,4 @@ object MQTTWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 2cd8073dada14..9a57fe286d1ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -57,3 +58,4 @@ object NetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a9aaa445bccb6..5322929d177b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -58,3 +59,4 @@ object RawNetworkGrep { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 751b30ea15782..9916882e4f94a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.File @@ -108,3 +109,4 @@ object RecoverableNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 5a6b9216a3fbc..ed617754cbf1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -99,3 +100,4 @@ object SQLContextSingleton { instance } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 345d0bc441351..02ba1c2eed0f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -78,3 +79,4 @@ object StatefulNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index c10de84a80ffe..825c671a929b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird._ @@ -113,3 +114,4 @@ object TwitterAlgebirdCMS { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 62db5e663b8af..49826ede70418 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird.HyperLogLogMonoid @@ -90,3 +91,4 @@ object TwitterAlgebirdHLL { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f253d75b279f7..49cee1b43c2dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -82,3 +83,4 @@ object TwitterPopularTags { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index e99d1baa72b9f..6ac9a72c37941 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import akka.actor.ActorSystem @@ -97,3 +98,4 @@ object ZeroMQWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 889f052c70263..bea7a47cb2855 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import java.net.ServerSocket @@ -108,3 +109,4 @@ object PageViewGenerator { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index fbacaee98690f..ec7d39da8b2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ @@ -107,3 +108,4 @@ object PageViewStream { ssc.start() } } +// scalastyle:on println diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 8e1715f6dbb95..5b3c79444aa68 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -111,7 +111,7 @@ class DirectKafkaStreamSuite rdd }.foreachRDD { rdd => for (o <- offsetRanges) { - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index be8b62d3cc6ba..de749626ec09c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.nio.ByteBuffer @@ -272,3 +273,4 @@ private[streaming] object StreamingExamples extends Logging { } } } +// scalastyle:on println diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index be6b9047d932d..5c07b415cd796 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -66,7 +66,6 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 9591c4e9b8f4e..989e226305265 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl /** A collection of graph generating functions. */ -object GraphGenerators { +object GraphGenerators extends Logging { val RMATa = 0.45 val RMATb = 0.15 @@ -142,7 +142,7 @@ object GraphGenerators { var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { - println(edges.size + " edges") + logDebug(edges.size + " edges") } edges += addEdge(numVertices) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 186d0cc2a977b..61e44dcab578c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +// scalastyle:off println class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass @@ -102,6 +103,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { private val c = {e: TestClass => println(e.baz)} } +// scalastyle:on println object BytecodeUtilsSuite { class TestClass(val foo: Int, val bar: Long) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index 6eaebaf7dba9f..e6bcff48b022c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -64,8 +64,10 @@ object KMeansDataGenerator { def main(args: Array[String]) { if (args.length < 6) { + // scalastyle:off println println("Usage: KMeansGenerator " + " []") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index b4e33c98ba7e5..87eeb5db05d26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -153,8 +153,10 @@ object LinearDataGenerator { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: LinearDataGenerator " + " [num_examples] [num_features] [num_partitions]") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 9d802678c4a77..c09cbe69bb971 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -64,8 +64,10 @@ object LogisticRegressionDataGenerator { def main(args: Array[String]) { if (args.length != 5) { + // scalastyle:off println println("Usage: LogisticRegressionGenerator " + " ") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index bd73a866c8a82..16f430599a515 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -55,8 +55,10 @@ import org.apache.spark.rdd.RDD object MFDataGenerator { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: MFDataGenerator " + " [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index a8e30cc9d730c..ad20b7694a779 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -37,8 +37,10 @@ object SVMDataGenerator { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: SVMGenerator " + " [num_examples] [num_features] [num_partitions]") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 8c85c96d5c6d8..03120c828ca96 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} @@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { import VectorIndexerSuite.FeatureData @@ -113,11 +113,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { model.transform(sparsePoints1) // should work intercept[SparkException] { model.transform(densePoints2).collect() - println("Did not throw error when fit, transform were called on vectors of different lengths") + logInfo("Did not throw error when fit, transform were called on vectors of different lengths") } intercept[SparkException] { vectorIndexer.fit(badPoints) - println("Did not throw error when fitting vectors of different lengths in same RDD.") + logInfo("Did not throw error when fitting vectors of different lengths in same RDD.") } } @@ -196,7 +196,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { } } catch { case e: org.scalatest.exceptions.TestFailedException => - println(errMsg) + logError(errMsg) throw e } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index c4ae0a16f7c04..178d95a7b94ec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -21,10 +21,10 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends SparkFunSuite { +class VectorsSuite extends SparkFunSuite with Logging { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 @@ -142,7 +142,7 @@ class VectorsSuite extends SparkFunSuite { malformatted.foreach { s => intercept[SparkException] { Vectors.parse(s) - println(s"Didn't detect malformatted string $s.") + logInfo(s"Didn't detect malformatted string $s.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index c292ced75e870..c3eeda012571c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.stat import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { // test input data val xData = Array(1.0, 0.0, -2.0) @@ -146,7 +146,7 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = { for (i <- 0 until A.rows; j <- 0 until A.cols) { if (!approxEqual(A(i, j), B(i, j), threshold)) { - println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) + logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) return false } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 84dd3b342d4c0..2521b3342181a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GradientBoostedTrees]]. */ -class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { @@ -50,7 +50,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } @@ -80,7 +80,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae") } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } @@ -111,7 +111,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9) } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index fa4f74d71b7e7..16d7c3ab39b03 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -33,7 +33,7 @@ class NumericParserSuite extends SparkFunSuite { malformatted.foreach { s => intercept[SparkException] { NumericParser.parse(s) - println(s"Didn't detect malformatted string $s.") + throw new RuntimeException(s"Didn't detect malformatted string $s.") } } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3408c6d51ed4c..4291b0be2a616 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -69,6 +69,7 @@ object SparkBuild extends PomBuild { import scala.collection.mutable var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq("sbt") + // scalastyle:off println if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") @@ -88,6 +89,7 @@ object SparkBuild extends PomBuild { println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.") profiles ++= Seq("yarn") } + // scalastyle:on println profiles } @@ -96,8 +98,10 @@ object SparkBuild extends PomBuild { case None => backwardCompatibility case Some(v) => if (backwardCompatibility.nonEmpty) + // scalastyle:off println println("Note: We ignore environment variables, when use of profile is detected in " + "conjunction with environment variable.") + // scalastyle:on println v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala index 6480e2d24e044..24fbbc12c08da 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala @@ -39,6 +39,8 @@ class SparkCommandLine(args: List[String], override val settings: Settings) } def this(args: List[String]) { + // scalastyle:off println this(args, str => Console.println("Error: " + str)) + // scalastyle:on println } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 2b235525250c2..8f7f9074d3f03 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1101,7 +1101,9 @@ object SparkILoop extends Logging { val s = super.readLine() // helping out by printing the line being interpreted. if (s != null) + // scalastyle:off println output.println(s) + // scalastyle:on println s } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 05faef8786d2c..bd3314d94eed6 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -80,11 +80,13 @@ private[repl] trait SparkILoopInit { if (!initIsComplete) withLock { while (!initIsComplete) initLoopCondition.await() } if (initError != null) { + // scalastyle:off println println(""" |Failed to initialize the REPL due to an unexpected error. |This is a bug, please, report it along with the error diagnostics printed below. |%s.""".stripMargin.format(initError) ) + // scalastyle:on println false } else true } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 35fb625645022..8791618bd355e 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1761,7 +1761,9 @@ object SparkIMain { if (intp.totalSilence) () else super.printMessage(msg) } + // scalastyle:off println else Console.println(msg) + // scalastyle:on println } } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 7a5e94da5cbf3..3c90287249497 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -943,7 +943,9 @@ object SparkILoop { val s = super.readLine() // helping out by printing the line being interpreted. if (s != null) + // scalastyle:off println output.println(s) + // scalastyle:on println s } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1cb910f376060..56c009a4e38e7 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -129,7 +129,9 @@ class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings } private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" private val logScope = scala.sys.props contains "scala.repl.scope" + // scalastyle:off println private def scopelog(msg: String) = if (logScope) Console.err.println(msg) + // scalastyle:on println // argument is a thunk to execute after init is done def initialize(postInitSignal: => Unit) { @@ -1297,8 +1299,10 @@ class SparkISettings(intp: SparkIMain) { def deprecation_=(x: Boolean) = { val old = intp.settings.deprecation.value intp.settings.deprecation.value = x + // scalastyle:off println if (!old && x) println("Enabled -deprecation output.") else if (old && !x) println("Disabled -deprecation output.") + // scalastyle:on println } def deprecation: Boolean = intp.settings.deprecation.value diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala index 0711ed4871bb6..272f81eca92c1 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala @@ -42,7 +42,9 @@ class SparkReplReporter(intp: SparkIMain) extends ConsoleReporter(intp.settings, } else super.printMessage(msg) } + // scalastyle:off println else Console.println("[init] " + msg) + // scalastyle:on println } override def displayPrompt() { diff --git a/scalastyle-config.xml b/scalastyle-config.xml index d6f927b6fa803..49611703798e8 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -141,12 +141,8 @@ This file is divided into 3 sections: Tests must extend org.apache.spark.SparkFunSuite instead. - - - - - - + + ^println$ + + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 7f1b12cdd5800..606fecbe06e47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -67,8 +67,10 @@ package object codegen { outfile.write(generatedBytes) outfile.close() + // scalastyle:off println println( s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!) + // scalastyle:on println } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2f545bb432165..b89e3382f06a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -154,7 +154,9 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def schemaString: String = schema.treeString /** Prints out the schema in the tree format */ + // scalastyle:off println def printSchema(): Unit = println(schemaString) + // scalastyle:on println /** * A prefix string used when printing the plan. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 07054166a5e88..71293475ca0f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -124,7 +124,9 @@ package object util { val startTime = System.nanoTime() val ret = f val endTime = System.nanoTime() + // scalastyle:off println println(s"${(endTime - startTime).toDouble / 1000000}ms") + // scalastyle:on println ret } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e0b8ff91786a7..b8097403ec3cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -250,7 +250,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru builder.toString() } + // scalastyle:off println def printTreeString(): Unit = println(treeString) + // scalastyle:on println private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { fields.foreach(field => field.buildFormattedString(prefix, builder)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f201c8ea8a110..10250264625b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -860,11 +860,13 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ def explain(extended: Boolean): Unit = { + // scalastyle:off println if (extended) { println(expr) } else { println(expr.prettyString) } + // scalastyle:on println } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d7966651b1948..830fba35bb7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -308,7 +308,9 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ + // scalastyle:off println def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println /** * Prints the plans (logical and physical) to the console for debugging purposes. @@ -319,7 +321,9 @@ class DataFrame private[sql]( ExplainCommand( queryExecution.logical, extended = extended).queryExecution.executedPlan.executeCollect().map { + // scalastyle:off println r => println(r.getString(0)) + // scalastyle:on println } } @@ -392,7 +396,9 @@ class DataFrame private[sql]( * @group action * @since 1.5.0 */ + // scalastyle:off println def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + // scalastyle:on println /** * Returns a [[DataFrameNaFunctions]] for working with missing data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 2964edac1aba2..e6081cb05bc2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator} +import org.apache.spark.{AccumulatorParam, Accumulator, Logging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -57,7 +57,7 @@ package object debug { * Augments [[DataFrame]]s with debug methods. */ @DeveloperApi - implicit class DebugQuery(query: DataFrame) { + implicit class DebugQuery(query: DataFrame) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() @@ -66,7 +66,7 @@ package object debug { visited += new TreeNodeRef(s) DebugNode(s) } - println(s"Results returned: ${debugPlan.execute().count()}") + logDebug(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { case d: DebugNode => d.dumpStats() case _ => @@ -82,11 +82,11 @@ package object debug { TypeCheck(s) } try { - println(s"Results returned: ${debugPlan.execute().count()}") + logDebug(s"Results returned: ${debugPlan.execute().count()}") } catch { case e: Exception => def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) - println(s"Deepest Error: ${unwrap(e)}") + logDebug(s"Deepest Error: ${unwrap(e)}") } } } @@ -119,11 +119,11 @@ package object debug { val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { - println(s"== ${child.simpleString} ==") - println(s"Tuples output: ${tupleCount.value}") + logDebug(s"== ${child.simpleString} ==") + logDebug(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case(attr, metric) => val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") - println(s" ${attr.name} ${attr.dataType}: $actualDataTypes") + logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 039cfa40d26b3..f66a17b20915f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -40,7 +40,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils -private[hive] object SparkSQLCLIDriver { +private[hive] object SparkSQLCLIDriver extends Logging { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ @@ -164,7 +164,7 @@ private[hive] object SparkSQLCLIDriver { } } catch { case e: FileNotFoundException => - System.err.println(s"Could not open input file for reading. (${e.getMessage})") + logError(s"Could not open input file for reading. (${e.getMessage})") System.exit(3) } @@ -180,14 +180,14 @@ private[hive] object SparkSQLCLIDriver { val historyFile = historyDirectory + File.separator + ".hivehistory" reader.setHistory(new History(new File(historyFile))) } else { - System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + logWarning("WARNING: Directory for Hive history file: " + historyDirectory + " does not exist. History will not be available during this session.") } } catch { case e: Exception => - System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + logWarning("WARNING: Encountered an error while trying to initialize Hive's " + "history file. History will not be available during this session.") - System.err.println(e.getMessage) + logWarning(e.getMessage) } val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") @@ -270,6 +270,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { + // scalastyle:off println if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver @@ -336,6 +337,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } ret = proc.run(cmd_1).getResponseCode } + // scalastyle:on println } ret } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index bbc39b892b79e..4684d48aff889 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -65,12 +66,12 @@ private[hive] class HiveQLDialect extends ParserDialect { * * @since 1.0.0 */ -class HiveContext(sc: SparkContext) extends SQLContext(sc) { +class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { self => import HiveContext._ - println("create HiveContext") + logDebug("create HiveContext") /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2de7a99c122fd..7fc517b646b20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -73,7 +74,7 @@ private[hive] case class CreateTableAsSelect( } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl { +private[hive] object HiveQl extends Logging { protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", @@ -186,7 +187,7 @@ private[hive] object HiveQl { .map(ast => Option(ast).map(_.transform(rule)).orNull)) } catch { case e: Exception => - println(dumpTree(n)) + logError(dumpTree(n).toString) throw e } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index cbd2bf6b5eede..9d83ca6c113dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -360,7 +360,9 @@ private[hive] class ClientWrapper( case _ => if (state.out != null) { + // scalastyle:off println state.out.println(tokens(0) + " " + cmd_1) + // scalastyle:on println } Seq(proc.run(cmd_1).getResponseCode.toString) } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index 0e428ba1d7456..2590040f2ec1c 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.hive.HiveContext */ object Main { def main(args: Array[String]) { + // scalastyle:off println println("Running regression test for SPARK-8489.") val sc = new SparkContext("local", "testing") val hc = new HiveContext(sc) @@ -38,6 +39,7 @@ object Main { val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() println("Regression test for SPARK-8489 success!") + // scalastyle:on println sc.stop() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index e9bb32667936c..983c013bcf86a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends SparkFunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { test("struct field should accept underscore in sub-column name") { val metastr = "struct" @@ -41,7 +41,7 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite { test("duplicated metastore relations") { import TestHive.implicits._ val df = TestHive.sql("SELECT * FROM src") - println(df.queryExecution) + logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index a38ed23b5cf9a..917900e5f46dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -90,8 +90,10 @@ class HiveSparkSubmitSuite "SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome ).run(ProcessLogger( + // scalastyle:off println (line: String) => { println(s"out> $line") }, (line: String) => { println(s"err> $line") } + // scalastyle:on println )) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index aa5dbe2db6903..508695919e9a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -86,8 +86,6 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { val message = intercept[QueryExecutionException] { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") }.getMessage - - println("message!!!!" + message) } test("Double create does not fail when allowExisting = true") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index cc294bc3e8bc3..d910af22c3dd1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException +import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive @@ -40,7 +41,8 @@ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll + with Logging { override val sqlContext = TestHive var jsonFilePath: String = _ @@ -415,7 +417,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA |) """.stripMargin) - sql("DROP TABLE jsonTable").collect().foreach(println) + sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index eaaa88e17002b..1bde5922b5278 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -315,7 +315,6 @@ class PairUDF extends GenericUDF { ) override def evaluate(args: Array[DeferredObject]): AnyRef = { - println("Type = %s".format(args(0).getClass.getName)) Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 192aa6a139bcb..1da0b0a54df07 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -720,12 +720,14 @@ abstract class DStream[T: ClassTag] ( def foreachFunc: (RDD[T], Time) => Unit = { (rdd: RDD[T], time: Time) => { val firstNum = rdd.take(num + 1) + // scalastyle:off println println("-------------------------------------------") println("Time: " + time) println("-------------------------------------------") firstNum.take(num).foreach(println) if (firstNum.length > num) println("...") println() + // scalastyle:on println } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index ca2f319f174a2..6addb96752038 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -35,7 +35,9 @@ private[streaming] object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { + // scalastyle:off println System.err.println("Usage: RawTextSender ") + // scalastyle:on println System.exit(1) } // Parse the arguments using a pattern match diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index c8eef833eb431..dd32ad5ad811d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -106,7 +106,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: } private[streaming] -object RecurringTimer { +object RecurringTimer extends Logging { def main(args: Array[String]) { var lastRecurTime = 0L @@ -114,7 +114,7 @@ object RecurringTimer { def onRecur(time: Long) { val currentTime = System.currentTimeMillis() - println("" + currentTime + ": " + (currentTime - lastRecurTime)) + logInfo("" + currentTime + ": " + (currentTime - lastRecurTime)) lastRecurTime = currentTime } val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index e0f14fd954280..6e9d4431090a2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -43,6 +43,7 @@ object MasterFailureTest extends Logging { @volatile var setupCalled = false def main(args: Array[String]) { + // scalastyle:off println if (args.size < 2) { println( "Usage: MasterFailureTest <# batches> " + @@ -60,6 +61,7 @@ object MasterFailureTest extends Logging { testUpdateStateByKey(directory, numBatches, batchDuration) println("\n\nSUCCESS\n\n") + // scalastyle:on println } def testMap(directory: String, numBatches: Int, batchDuration: Duration) { @@ -291,10 +293,12 @@ object MasterFailureTest extends Logging { } // Log the output + // scalastyle:off println println("Expected output, size = " + expectedOutput.size) println(expectedOutput.mkString("[", ",", "]")) println("Output, size = " + output.size) println(output.mkString("[", ",", "]")) + // scalastyle:on println // Match the output with the expected output output.foreach(o => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 7865b06c2e3c2..a2dbae149f311 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -76,7 +76,6 @@ class JobGeneratorSuite extends TestSuiteBase { if (time.milliseconds == longBatchTime) { while (waitLatch.getCount() > 0) { waitLatch.await() - println("Await over") } } }) diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 595ded6ae67fa..9483d2b692ab5 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -92,7 +92,9 @@ object GenerateMIMAIgnore { ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { + // scalastyle:off println case _: Throwable => println("Error instrumenting class:" + className) + // scalastyle:on println } } (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) @@ -108,7 +110,9 @@ object GenerateMIMAIgnore { .filter(_.contains("$$")).map(classSymbol.fullName + "." + _) } catch { case t: Throwable => + // scalastyle:off println println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + // scalastyle:on println Seq.empty[String] } } @@ -128,12 +132,14 @@ object GenerateMIMAIgnore { getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-class-excludes") .writeAll(previousContents + privateClasses.mkString("\n")) + // scalastyle:off println println("Created : .generated-mima-class-excludes in current directory.") val previousMembersContents = Try(File(".generated-mima-member-excludes").lines) .getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-member-excludes").writeAll(previousMembersContents + privateMembers.mkString("\n")) println("Created : .generated-mima-member-excludes in current directory.") + // scalastyle:on println } @@ -174,7 +180,9 @@ object GenerateMIMAIgnore { try { classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) } catch { + // scalastyle:off println case _: Throwable => println("Unable to load:" + entry) + // scalastyle:on println } } classes diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 583823c90c5c6..856ea177a9a10 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -323,11 +323,14 @@ object JavaAPICompletenessChecker { val missingMethods = javaEquivalents -- javaMethods for (method <- missingMethods) { + // scalastyle:off println println(method) + // scalastyle:on println } } def main(args: Array[String]) { + // scalastyle:off println println("Missing RDD methods") printMissingMethods(classOf[RDD[_]], classOf[JavaRDD[_]]) println() @@ -359,5 +362,6 @@ object JavaAPICompletenessChecker { println("Missing PairDStream methods") printMissingMethods(classOf[PairDStreamFunctions[_, _]], classOf[JavaPairDStream[_, _]]) println() + // scalastyle:on println } } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index baa97616eaff3..0dc2861253f17 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -85,7 +85,9 @@ object StoragePerfTester { latch.countDown() } catch { case e: Exception => + // scalastyle:off println println("Exception in child thread: " + e + " " + e.getMessage) + // scalastyle:on println System.exit(1) } } @@ -97,9 +99,11 @@ object StoragePerfTester { val bytesPerSecond = totalBytes.get() / time val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + // scalastyle:off println System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + // scalastyle:on println executor.shutdown() sc.stop() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 68e9f6b4db7f4..37f793763367e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -85,7 +85,9 @@ class ApplicationMasterArguments(val args: Array[String]) { } if (primaryPyFile != null && primaryRFile != null) { + // scalastyle:off println System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + // scalastyle:on println System.exit(-1) } @@ -93,6 +95,7 @@ class ApplicationMasterArguments(val args: Array[String]) { } def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + // scalastyle:off println if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) } @@ -111,6 +114,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) """.stripMargin) + // scalastyle:on println System.exit(exitCode) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4d52ae774ea00..f0af6f875f523 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -938,7 +938,7 @@ private[spark] class Client( object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { - println("WARNING: This client is deprecated and will be removed in a " + + logWarning("WARNING: This client is deprecated and will be removed in a " + "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 19d1bbff9993f..20d63d40cf605 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -123,6 +123,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new SparkException("Executor cores must not be less than " + "spark.task.cpus.") } + // scalastyle:off println if (isClusterMode) { for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { if (sparkConf.contains(key)) { @@ -144,11 +145,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .map(_.toInt) .foreach { cores => amCores = cores } } + // scalastyle:on println } private def parseArgs(inputArgs: List[String]): Unit = { var args = inputArgs + // scalastyle:off println while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => @@ -253,6 +256,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + // scalastyle:on println if (primaryPyFile != null && primaryRFile != null) { throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 335e966519c7c..547863d9a0739 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -328,12 +328,14 @@ private object YarnClusterDriver extends Logging with Matchers { def main(args: Array[String]): Unit = { if (args.length != 1) { + // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClusterDriver [result file] """.stripMargin) + // scalastyle:on println System.exit(1) } @@ -386,12 +388,14 @@ private object YarnClasspathTest { def main(args: Array[String]): Unit = { if (args.length != 2) { + // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClasspathTest [driver result file] [executor result file] """.stripMargin) + // scalastyle:on println System.exit(1) } From 11e22b74a080ea58fb9410b5cc6fa4c03f9198f2 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 10 Jul 2015 16:22:49 +0100 Subject: [PATCH 242/274] [SPARK-7944] [SPARK-8013] Remove most of the Spark REPL fork for Scala 2.11 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR removes most of the code in the Spark REPL for Scala 2.11 and leaves just a couple of overridden methods in `SparkILoop` in order to: - change welcome message - restrict available commands (like `:power`) - initialize Spark context The two codebases have diverged and it's extremely hard to backport fixes from the upstream REPL. This somewhat radical step is absolutely necessary in order to fix other REPL tickets (like SPARK-8013 - Hive Thrift server for 2.11). BTW, the Scala REPL has fixed the serialization-unfriendly wrappers thanks to ScrapCodes's work in [#4522](https://github.com/scala/scala/pull/4522) All tests pass and I tried the `spark-shell` on our Mesos cluster with some simple jobs (including with additional jars), everything looked good. As soon as Scala 2.11.7 is out we need to upgrade and get a shaded `jline` dependency, clearing the way for SPARK-8013. /cc pwendell Author: Iulian Dragos Closes #6903 from dragos/issue/no-spark-repl-fork and squashes the following commits: c596c6f [Iulian Dragos] Merge branch 'master' into issue/no-spark-repl-fork 2b1a305 [Iulian Dragos] Removed spaces around multiple imports. 0ce67a6 [Iulian Dragos] Remove -verbose flag for java compiler (added by mistake in an earlier commit). 10edaf9 [Iulian Dragos] Keep the jline dependency only in the 2.10 build. 529293b [Iulian Dragos] Add back Spark REPL files to rat-excludes, since they are part of the 2.10 real. d85370d [Iulian Dragos] Remove jline dependency from the Spark REPL. b541930 [Iulian Dragos] Merge branch 'master' into issue/no-spark-repl-fork 2b15962 [Iulian Dragos] Change jline dependency and bump Scala version. b300183 [Iulian Dragos] Rename package and add license on top of the file, remove files from rat-excludes and removed `-Yrepl-sync` per reviewer’s request. 9d46d85 [Iulian Dragos] Fix SPARK-7944. abcc7cb [Iulian Dragos] Remove the REPL forked code. --- pom.xml | 18 +- repl/pom.xml | 19 +- .../scala/org/apache/spark/repl/Main.scala | 16 +- .../apache/spark/repl/SparkExprTyper.scala | 86 -- .../org/apache/spark/repl/SparkILoop.scala | 971 +----------- .../org/apache/spark/repl/SparkIMain.scala | 1323 ----------------- .../org/apache/spark/repl/SparkImports.scala | 201 --- .../spark/repl/SparkJLineCompletion.scala | 350 ----- .../spark/repl/SparkMemberHandlers.scala | 221 --- .../apache/spark/repl/SparkReplReporter.scala | 55 - .../org/apache/spark/repl/ReplSuite.scala | 11 +- 11 files changed, 90 insertions(+), 3181 deletions(-) delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkImports.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala diff --git a/pom.xml b/pom.xml index 172fdef4c73da..c2ebc1a11e770 100644 --- a/pom.xml +++ b/pom.xml @@ -341,11 +341,6 @@ - - ${jline.groupid} - jline - ${jline.version} - com.twitter chill_${scala.binary.version} @@ -1826,6 +1821,15 @@ ${scala.version} org.scala-lang + + + + ${jline.groupid} + jline + ${jline.version} + + + @@ -1844,10 +1848,8 @@ scala-2.11 - 2.11.6 + 2.11.7 2.11 - 2.12.1 - jline diff --git a/repl/pom.xml b/repl/pom.xml index 370b2bc2fa8ed..70c9bd7c01296 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -38,11 +38,6 @@ - - ${jline.groupid} - jline - ${jline.version} - org.apache.spark spark-core_${scala.binary.version} @@ -161,6 +156,20 @@ + + scala-2.10 + + !scala-2.11 + + + + ${jline.groupid} + jline + ${jline.version} + + + + scala-2.11 diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index f4f4b626988e9..eed4a379afa60 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -17,13 +17,14 @@ package org.apache.spark.repl +import java.io.File + +import scala.tools.nsc.Settings + import org.apache.spark.util.Utils import org.apache.spark._ import org.apache.spark.sql.SQLContext -import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.SparkILoop - object Main extends Logging { val conf = new SparkConf() @@ -32,7 +33,8 @@ object Main extends Logging { val outputDir = Utils.createTempDir(rootDir) val s = new Settings() s.processArguments(List("-Yrepl-class-based", - "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", getAddedJars.mkString(File.pathSeparator)), true) val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ @@ -48,7 +50,6 @@ object Main extends Logging { Option(sparkContext).map(_.stop) } - def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") if (envJars.isDefined) { @@ -84,10 +85,9 @@ object Main extends Logging { val loader = Utils.getContextOrSparkClassLoader try { sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) - .newInstance(sparkContext).asInstanceOf[SQLContext] + .newInstance(sparkContext).asInstanceOf[SQLContext] logInfo("Created sql context (with Hive support)..") - } - catch { + } catch { case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError => sqlContext = new SQLContext(sparkContext) logInfo("Created sql context..") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala deleted file mode 100644 index 8e519fa67f649..0000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package scala.tools.nsc -package interpreter - -import scala.tools.nsc.ast.parser.Tokens.EOF - -trait SparkExprTyper { - val repl: SparkIMain - - import repl._ - import global.{ reporter => _, Import => _, _ } - import naming.freshInternalVarName - - def symbolOfLine(code: String): Symbol = { - def asExpr(): Symbol = { - val name = freshInternalVarName() - // Typing it with a lazy val would give us the right type, but runs - // into compiler bugs with things like existentials, so we compile it - // behind a def and strip the NullaryMethodType which wraps the expr. - val line = "def " + name + " = " + code - - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - // drop NullaryMethodType - sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) - case _ => NoSymbol - } - } - def asDefn(): Symbol = { - val old = repl.definedSymbolList.toSet - - interpretSynthetic(code) match { - case IR.Success => - repl.definedSymbolList filterNot old match { - case Nil => NoSymbol - case sym :: Nil => sym - case syms => NoSymbol.newOverloaded(NoPrefix, syms) - } - case _ => NoSymbol - } - } - def asError(): Symbol = { - interpretSynthetic(code) - NoSymbol - } - beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() - } - - private var typeOfExpressionDepth = 0 - def typeOfExpression(expr: String, silent: Boolean = true): Type = { - if (typeOfExpressionDepth > 2) { - repldbg("Terminating typeOfExpression recursion for expression: " + expr) - return NoType - } - typeOfExpressionDepth += 1 - // Don't presently have a good way to suppress undesirable success output - // while letting errors through, so it is first trying it silently: if there - // is an error, and errors are desired, then it re-evaluates non-silently - // to induce the error message. - try beSilentDuring(symbolOfLine(expr).tpe) match { - case NoType if !silent => symbolOfLine(expr).tpe // generate error - case tpe => tpe - } - finally typeOfExpressionDepth -= 1 - } - - // This only works for proper types. - def typeOfTypeString(typeString: String): Type = { - def asProperType(): Option[Type] = { - val name = freshInternalVarName() - val line = "def %s: %s = ???" format (name, typeString) - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - Some(sym0.asMethod.returnType) - case _ => None - } - } - beSilentDuring(asProperType()) getOrElse NoType - } -} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 3c90287249497..bf609ff0f65fc 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1,88 +1,64 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Alexander Spoon +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -package scala -package tools.nsc -package interpreter +package org.apache.spark.repl -import scala.language.{ implicitConversions, existentials } -import scala.annotation.tailrec -import Predef.{ println => _, _ } -import interpreter.session._ -import StdReplTags._ -import scala.reflect.api.{Mirror, Universe, TypeCreator} -import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName } -import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } -import scala.reflect.{ClassTag, classTag} -import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader } -import ScalaClassLoader._ -import scala.reflect.io.{ File, Directory } -import scala.tools.util._ -import scala.collection.generic.Clearable -import scala.concurrent.{ ExecutionContext, Await, Future, future } -import ExecutionContext.Implicits._ -import java.io.{ BufferedReader, FileReader } +import java.io.{BufferedReader, FileReader} -/** The Scala interactive shell. It provides a read-eval-print loop - * around the Interpreter class. - * After instantiation, clients should call the main() method. - * - * If no in0 is specified, then input will come from the console, and - * the class will attempt to provide input editing feature such as - * input history. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - * @version 1.2 - */ -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) - extends AnyRef - with LoopCommands -{ - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) -// -// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp -// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i - - var in: InteractiveReader = _ // the input stream from which commands come - var settings: Settings = _ - var intp: SparkIMain = _ +import Predef.{println => _, _} +import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName} - var globalFuture: Future[Boolean] = _ +import scala.tools.nsc.interpreter.{JPrintWriter, ILoop} +import scala.tools.nsc.Settings +import scala.tools.nsc.util.stringFromStream - protected def asyncMessage(msg: String) { - if (isReplInfo || isReplPower) - echoAndRefresh(msg) - } +/** + * A Spark-specific interactive shell. + */ +class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) + extends ILoop(in0, out) { + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) def initializeSpark() { intp.beQuietDuring { - command( """ + processLine(""" @transient val sc = { val _sc = org.apache.spark.repl.Main.createSparkContext() println("Spark context available as sc.") _sc } """) - command( """ + processLine(""" @transient val sqlContext = { val _sqlContext = org.apache.spark.repl.Main.createSQLContext() println("SQL context available as sqlContext.") _sqlContext } """) - command("import org.apache.spark.SparkContext._") - command("import sqlContext.implicits._") - command("import sqlContext.sql") - command("import org.apache.spark.sql.functions._") + processLine("import org.apache.spark.SparkContext._") + processLine("import sqlContext.implicits._") + processLine("import sqlContext.sql") + processLine("import org.apache.spark.sql.functions._") } } /** Print a welcome message */ - def printWelcome() { + override def printWelcome() { import org.apache.spark.SPARK_VERSION echo("""Welcome to ____ __ @@ -98,877 +74,42 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) echo("Type :help for more information.") } - override def echoCommandMessage(msg: String) { - intp.reporter printUntruncatedMessage msg - } - - // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - def history = in.history - - // classpath entries added via :cp - var addedClasspath: String = "" - - /** A reverse list of commands to replay if the user requests a :replay */ - var replayCommandStack: List[String] = Nil - - /** A list of commands to replay if the user requests a :replay */ - def replayCommands = replayCommandStack.reverse - - /** Record a command for replay should the user request a :replay */ - def addReplay(cmd: String) = replayCommandStack ::= cmd - - def savingReplayStack[T](body: => T): T = { - val saved = replayCommandStack - try body - finally replayCommandStack = saved - } - def savingReader[T](body: => T): T = { - val saved = in - try body - finally in = saved - } - - /** Close the interpreter and set the var to null. */ - def closeInterpreter() { - if (intp ne null) { - intp.close() - intp = null - } - } - - class SparkILoopInterpreter extends SparkIMain(settings, out) { - outer => - - override lazy val formatting = new Formatting { - def prompt = SparkILoop.this.prompt - } - override protected def parentClassLoader = - settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader ) - } - - /** Create a new interpreter. */ - def createInterpreter() { - if (addedClasspath != "") - settings.classpath append addedClasspath - - intp = new SparkILoopInterpreter - } - - /** print a friendly help message */ - def helpCommand(line: String): Result = { - if (line == "") helpSummary() - else uniqueCommand(line) match { - case Some(lc) => echo("\n" + lc.help) - case _ => ambiguousError(line) - } - } - private def helpSummary() = { - val usageWidth = commands map (_.usageMsg.length) max - val formatStr = "%-" + usageWidth + "s %s" - - echo("All commands can be abbreviated, e.g. :he instead of :help.") - - commands foreach { cmd => - echo(formatStr.format(cmd.usageMsg, cmd.help)) - } - } - private def ambiguousError(cmd: String): Result = { - matchingCommands(cmd) match { - case Nil => echo(cmd + ": no such command. Type :help for help.") - case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") - } - Result(keepRunning = true, None) - } - private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) - private def uniqueCommand(cmd: String): Option[LoopCommand] = { - // this lets us add commands willy-nilly and only requires enough command to disambiguate - matchingCommands(cmd) match { - case List(x) => Some(x) - // exact match OK even if otherwise appears ambiguous - case xs => xs find (_.name == cmd) - } - } - - /** Show the history */ - lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { - override def usage = "[num]" - def defaultLines = 20 - - def apply(line: String): Result = { - if (history eq NoHistory) - return "No history available." - - val xs = words(line) - val current = history.index - val count = try xs.head.toInt catch { case _: Exception => defaultLines } - val lines = history.asStrings takeRight count - val offset = current - lines.size + 1 - - for ((line, index) <- lines.zipWithIndex) - echo("%3d %s".format(index + offset, line)) - } - } - - // When you know you are most likely breaking into the middle - // of a line being typed. This softens the blow. - protected def echoAndRefresh(msg: String) = { - echo("\n" + msg) - in.redrawLine() - } - protected def echo(msg: String) = { - out println msg - out.flush() - } - - /** Search the history */ - def searchHistory(_cmdline: String) { - val cmdline = _cmdline.toLowerCase - val offset = history.index - history.size + 1 - - for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) - echo("%d %s".format(index + offset, line)) - } - - private val currentPrompt = Properties.shellPromptString - - /** Prompt to print when awaiting input */ - def prompt = currentPrompt - import LoopCommand.{ cmd, nullary } - /** Standard commands **/ - lazy val standardCommands = List( - cmd("cp", "", "add a jar or directory to the classpath", addClasspath), - cmd("edit", "|", "edit history", editCommand), - cmd("help", "[command]", "print this summary or command-specific help", helpCommand), - historyCommand, - cmd("h?", "", "search the history", searchHistory), - cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), - //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand), - cmd("javap", "", "disassemble a file or class name", javapCommand), - cmd("line", "|", "place line(s) at the end of history", lineCommand), - cmd("load", "", "interpret lines in a file", loadCommand), - cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand), - // nullary("power", "enable power user mode", powerCmd), - nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)), - nullary("replay", "reset execution and replay all previous commands", replay), - nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), - cmd("save", "", "save replayable session to a file", saveCommand), - shCommand, - cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings), - nullary("silent", "disable/enable automatic printing of results", verbosity), -// cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), -// cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand), - nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) - ) - - /** Power user commands */ -// lazy val powerCommands: List[LoopCommand] = List( -// cmd("phase", "", "set the implicit phase for power commands", phaseCommand) -// ) - - private def importsCommand(line: String): Result = { - val tokens = words(line) - val handlers = intp.languageWildcardHandlers ++ intp.importHandlers - - handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { - case (handler, idx) => - val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) - val imps = handler.implicitSymbols - val found = tokens filter (handler importsSymbolNamed _) - val typeMsg = if (types.isEmpty) "" else types.size + " types" - val termMsg = if (terms.isEmpty) "" else terms.size + " terms" - val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" - val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") - val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - - intp.reporter.printMessage("%2d) %-30s %s%s".format( - idx + 1, - handler.importString, - statsMsg, - foundMsg - )) - } - } - - private def findToolsJar() = PathResolver.SupplementalLocations.platformTools + private val blockedCommands = Set("implicits", "javap", "power", "type", "kind") - private def addToolsJarToLoader() = { - val cl = findToolsJar() match { - case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) - case _ => intp.classLoader - } - if (Javap.isAvailable(cl)) { - repldbg(":javap available.") - cl - } - else { - repldbg(":javap unavailable: no tools.jar at " + jdkHome) - intp.classLoader - } - } -// -// protected def newJavap() = -// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp)) -// -// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) - - // Still todo: modules. -// private def typeCommand(line0: String): Result = { -// line0.trim match { -// case "" => ":type [-v] " -// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") -// } -// } - -// private def kindCommand(expr: String): Result = { -// expr.trim match { -// case "" => ":kind [-v] " -// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") -// } -// } - - private def warningsCommand(): Result = { - if (intp.lastWarnings.isEmpty) - "Can't find any cached warnings." - else - intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } - } - - private def changeSettings(args: String): Result = { - def showSettings() = { - for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString) - } - def updateSettings() = { - // put aside +flag options - val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+")) - val tmps = new Settings - val (ok, leftover) = tmps.processArguments(rest, processAll = true) - if (!ok) echo("Bad settings request.") - else if (leftover.nonEmpty) echo("Unprocessed settings.") - else { - // boolean flags set-by-user on tmp copy should be off, not on - val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting]) - val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg)) - // update non-flags - settings.processArguments(nonbools, processAll = true) - // also snag multi-value options for clearing, e.g. -Ylog: and -language: - for { - s <- settings.userSetSettings - if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting] - if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init)) - } s match { - case c: Clearable => c.clear() - case _ => - } - def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = { - for (b <- bs) - settings.lookupSetting(name(b)) match { - case Some(s) => - if (s.isInstanceOf[Settings#BooleanSetting]) setter(s) - else echo(s"Not a boolean flag: $b") - case _ => - echo(s"Not an option: $b") - } - } - update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off - update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on - } - } - if (args.isEmpty) showSettings() else updateSettings() - } - - private def javapCommand(line: String): Result = { -// if (javap == null) -// ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) -// else if (line == "") -// ":javap [-lcsvp] [path1 path2 ...]" -// else -// javap(words(line)) foreach { res => -// if (res.isError) return "Failed: " + res.value -// else res.show() -// } - } - - private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent" - - private def phaseCommand(name: String): Result = { -// val phased: Phased = power.phased -// import phased.NoPhaseName -// -// if (name == "clear") { -// phased.set(NoPhaseName) -// intp.clearExecutionWrapper() -// "Cleared active phase." -// } -// else if (name == "") phased.get match { -// case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" -// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) -// } -// else { -// val what = phased.parse(name) -// if (what.isEmpty || !phased.set(what)) -// "'" + name + "' does not appear to represent a valid phase." -// else { -// intp.setExecutionWrapper(pathToPhaseWrapper) -// val activeMessage = -// if (what.toString.length == name.length) "" + what -// else "%s (%s)".format(what, name) -// -// "Active phase is now: " + activeMessage -// } -// } - } + /** Standard commands **/ + lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] = + standardCommands.filter(cmd => !blockedCommands(cmd.name)) /** Available commands */ - def commands: List[LoopCommand] = standardCommands ++ ( - // if (isReplPower) - // powerCommands - // else - Nil - ) - - val replayQuestionMessage = - """|That entry seems to have slain the compiler. Shall I replay - |your session? I can re-run each line except the last one. - |[y/n] - """.trim.stripMargin - - private val crashRecovery: PartialFunction[Throwable, Boolean] = { - case ex: Throwable => - val (err, explain) = ( - if (intp.isInitializeComplete) - (intp.global.throwableAsString(ex), "") - else - (ex.getMessage, "The compiler did not initialize.\n") - ) - echo(err) - - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true - } - - // return false if repl should exit - def processLine(line: String): Boolean = { - import scala.concurrent.duration._ - Await.ready(globalFuture, 60.seconds) - - (line ne null) && (command(line) match { - case Result(false, _) => false - case Result(_, Some(line)) => addReplay(line) ; true - case _ => true - }) - } - - private def readOneLine() = { - out.flush() - in readLine prompt - } - - /** The main read-eval-print loop for the repl. It calls - * command() for each line of input, and stops when - * command() returns false. - */ - @tailrec final def loop() { - if ( try processLine(readOneLine()) catch crashRecovery ) - loop() - } - - /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { - savingReader { - savingReplayStack { - file applyReader { reader => - in = SimpleReader(reader, out, interactive = false) - echo("Loading " + file + "...") - loop() - } - } - } - } - - /** create a new interpreter and replay the given commands */ - def replay() { - reset() - if (replayCommandStack.isEmpty) - echo("Nothing to replay.") - else for (cmd <- replayCommands) { - echo("Replaying: " + cmd) // flush because maybe cmd will have its own output - command(cmd) - echo("") - } - } - def resetCommand() { - echo("Resetting interpreter state.") - if (replayCommandStack.nonEmpty) { - echo("Forgetting this session history:\n") - replayCommands foreach echo - echo("") - replayCommandStack = Nil - } - if (intp.namedDefinedTerms.nonEmpty) - echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) - if (intp.definedTypes.nonEmpty) - echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) - - reset() - } - def reset() { - intp.reset() - unleashAndSetPhase() - } - - def lineCommand(what: String): Result = editCommand(what, None) - - // :edit id or :edit line - def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR")) - - def editCommand(what: String, editor: Option[String]): Result = { - def diagnose(code: String) = { - echo("The edited code is incomplete!\n") - val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") - if (errless) echo("The compiler reports no errors.") - } - def historicize(text: String) = history match { - case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true - case _ => false - } - def edit(text: String): Result = editor match { - case Some(ed) => - val tmp = File.makeTemp() - tmp.writeAll(text) - try { - val pr = new ProcessResult(s"$ed ${tmp.path}") - pr.exitCode match { - case 0 => - tmp.safeSlurp() match { - case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.") - case Some(edited) => - echo(edited.lines map ("+" + _) mkString "\n") - val res = intp interpret edited - if (res == IR.Incomplete) diagnose(edited) - else { - historicize(edited) - Result(lineToRecord = Some(edited), keepRunning = true) - } - case None => echo("Can't read edited text. Did you delete it?") - } - case x => echo(s"Error exit from $ed ($x), ignoring") - } - } finally { - tmp.delete() - } - case None => - if (historicize(text)) echo("Placing text in recent history.") - else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text") - } - - // if what is a number, use it as a line number or range in history - def isNum = what forall (c => c.isDigit || c == '-' || c == '+') - // except that "-" means last value - def isLast = (what == "-") - if (isLast || !isNum) { - val name = if (isLast) intp.mostRecentVar else what - val sym = intp.symbolOfIdent(name) - intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match { - case Some(req) => edit(req.line) - case None => echo(s"No symbol in scope: $what") - } - } else try { - val s = what - // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur) - val (start, len) = - if ((s indexOf '+') > 0) { - val (a,b) = s splitAt (s indexOf '+') - (a.toInt, b.drop(1).toInt) - } else { - (s indexOf '-') match { - case -1 => (s.toInt, 1) - case 0 => val n = s.drop(1).toInt ; (history.index - n, n) - case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n) - case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n) - } - } - import scala.collection.JavaConverters._ - val index = (start - 1) max 0 - val text = history match { - case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n" - case _ => history.asStrings.slice(index, index + len) mkString "\n" - } - edit(text) - } catch { - case _: NumberFormatException => echo(s"Bad range '$what'") - echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)") - } - } - - /** fork a shell and run a command */ - lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { - override def usage = "" - def apply(line: String): Result = line match { - case "" => showUsage() - case _ => - val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})" - intp interpret toRun - () - } - } - - def withFile[A](filename: String)(action: File => A): Option[A] = { - val res = Some(File(filename)) filter (_.exists) map action - if (res.isEmpty) echo("That file does not exist") // courtesy side-effect - res - } - - def loadCommand(arg: String) = { - var shouldReplay: Option[String] = None - withFile(arg)(f => { - interpretAllFrom(f) - shouldReplay = Some(":load " + arg) - }) - Result(keepRunning = true, shouldReplay) - } - - def saveCommand(filename: String): Result = ( - if (filename.isEmpty) echo("File name is required.") - else if (replayCommandStack.isEmpty) echo("No replay commands in session") - else File(filename).printlnAll(replayCommands: _*) - ) - - def addClasspath(arg: String): Unit = { - val f = File(arg).normalize - if (f.exists) { - addedClasspath = ClassPath.join(addedClasspath, f.path) - val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) - echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) - replay() - } - else echo("The path '" + f + "' doesn't seem to exist.") - } - - def powerCmd(): Result = { - if (isReplPower) "Already in power mode." - else enablePowerMode(isDuringInit = false) - } - def enablePowerMode(isDuringInit: Boolean) = { - replProps.power setValue true - unleashAndSetPhase() - // asyncEcho(isDuringInit, power.banner) - } - private def unleashAndSetPhase() { - if (isReplPower) { - // power.unleash() - // Set the phase to "typer" - // intp beSilentDuring phaseCommand("typer") - } - } - - def asyncEcho(async: Boolean, msg: => String) { - if (async) asyncMessage(msg) - else echo(msg) - } - - def verbosity() = { - val old = intp.printResults - intp.printResults = !old - echo("Switched " + (if (old) "off" else "on") + " result printing.") - } - - /** Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. */ - def command(line: String): Result = { - if (line startsWith ":") { - val cmd = line.tail takeWhile (x => !x.isWhitespace) - uniqueCommand(cmd) match { - case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) - case _ => ambiguousError(cmd) - } - } - else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler - else Result(keepRunning = true, interpretStartingWith(line)) - } - - private def readWhile(cond: String => Boolean) = { - Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) - } - - def pasteCommand(arg: String): Result = { - var shouldReplay: Option[String] = None - def result = Result(keepRunning = true, shouldReplay) - val (raw, file) = - if (arg.isEmpty) (false, None) - else { - val r = """(-raw)?(\s+)?([^\-]\S*)?""".r - arg match { - case r(flag, sep, name) => - if (flag != null && name != null && sep == null) - echo(s"""I assume you mean "$flag $name"?""") - (flag != null, Option(name)) - case _ => - echo("usage: :paste -raw file") - return result - } - } - val code = file match { - case Some(name) => - withFile(name)(f => { - shouldReplay = Some(s":paste $arg") - val s = f.slurp.trim - if (s.isEmpty) echo(s"File contains no code: $f") - else echo(s"Pasting file $f...") - s - }) getOrElse "" - case None => - echo("// Entering paste mode (ctrl-D to finish)\n") - val text = (readWhile(_ => true) mkString "\n").trim - if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n") - else echo("\n// Exiting paste mode, now interpreting.\n") - text - } - def interpretCode() = { - val res = intp interpret code - // if input is incomplete, let the compiler try to say why - if (res == IR.Incomplete) { - echo("The pasted code is incomplete!\n") - // Remembrance of Things Pasted in an object - val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") - if (errless) echo("...but compilation found no error? Good luck with that.") - } - } - def compileCode() = { - val errless = intp compileSources new BatchSourceFile("", code) - if (!errless) echo("There were compilation errors!") - } - if (code.nonEmpty) { - if (raw) compileCode() else interpretCode() - } - result - } - - private object paste extends Pasted { - val ContinueString = " | " - val PromptString = "scala> " - - def interpret(line: String): Unit = { - echo(line.trim) - intp interpret line - echo("") - } - - def transcript(start: String) = { - echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") - apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) - } - } - import paste.{ ContinueString, PromptString } - - /** Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ - def interpretStartingWith(code: String): Option[String] = { - // signal completion non-completion input has been received - in.completion.resetVerbosity() - - def reallyInterpret = { - val reallyResult = intp.interpret(code) - (reallyResult, reallyResult match { - case IR.Error => None - case IR.Success => Some(code) - case IR.Incomplete => - if (in.interactive && code.endsWith("\n\n")) { - echo("You typed two blank lines. Starting a new command.") - None - } - else in.readLine(ContinueString) match { - case null => - // we know compilation is going to fail since we're at EOF and the - // parser thinks the input is still incomplete, but since this is - // a file being read non-interactively we want to fail. So we send - // it straight to the compiler for the nice error message. - intp.compileString(code) - None - - case line => interpretStartingWith(code + "\n" + line) - } - }) - } - - /** Here we place ourselves between the user and the interpreter and examine - * the input they are ostensibly submitting. We intervene in several cases: - * - * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. - * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation - * on the previous result. - * 3) If the Completion object's execute returns Some(_), we inject that value - * and avoid the interpreter, as it's likely not valid scala code. - */ - if (code == "") None - else if (!paste.running && code.trim.startsWith(PromptString)) { - paste.transcript(code) - None - } - else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { - interpretStartingWith(intp.mostRecentVar + code) - } - else if (code.trim startsWith "//") { - // line comment, do nothing - None - } - else - reallyInterpret._2 - } - - // runs :load `file` on any files passed via -i - def loadFiles(settings: Settings) = settings match { - case settings: GenericRunnerSettings => - for (filename <- settings.loadfiles.value) { - val cmd = ":load " + filename - command(cmd) - addReplay(cmd) - echo("") - } - case _ => - } - - /** Tries to create a JLineReader, falling back to SimpleReader: - * unless settings or properties are such that it should start - * with SimpleReader. - */ - def chooseReader(settings: Settings): InteractiveReader = { - if (settings.Xnojline || Properties.isEmacsShell) - SimpleReader() - else try new JLineReader( - if (settings.noCompletion) NoCompletion - else new SparkJLineCompletion(intp) - ) - catch { - case ex @ (_: Exception | _: NoClassDefFoundError) => - echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.") - SimpleReader() - } - } - protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = - u.TypeTag[T]( - m, - new TypeCreator { - def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type = - m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] - }) - - private def loopPostInit() { - // Bind intp somewhere out of the regular namespace where - // we can get at it in generated code. - intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain])) - // Auto-run code via some setting. - ( replProps.replAutorunCode.option - flatMap (f => io.File(f).safeSlurp()) - foreach (intp quietRun _) - ) - // classloader and power mode setup - intp.setContextClassLoader() - if (isReplPower) { - // replProps.power setValue true - // unleashAndSetPhase() - // asyncMessage(power.banner) - } - // SI-7418 Now, and only now, can we enable TAB completion. - in match { - case x: JLineReader => x.consoleReader.postInit - case _ => - } - } - def process(settings: Settings): Boolean = savingContextLoader { - this.settings = settings - createInterpreter() - - // sets in to some kind of reader depending on environmental cues - in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) - globalFuture = future { - intp.initializeSynchronous() - loopPostInit() - !intp.reporter.hasErrors - } - import scala.concurrent.duration._ - Await.ready(globalFuture, 10 seconds) - printWelcome() + override def commands: List[LoopCommand] = sparkStandardCommands + + /** + * We override `loadFiles` because we need to initialize Spark *before* the REPL + * sees any files, so that the Spark context is visible in those files. This is a bit of a + * hack, but there isn't another hook available to us at this point. + */ + override def loadFiles(settings: Settings): Unit = { initializeSpark() - loadFiles(settings) - - try loop() - catch AbstractOrMissingHandler() - finally closeInterpreter() - - true + super.loadFiles(settings) } - - @deprecated("Use `process` instead", "2.9.0") - def main(settings: Settings): Unit = process(settings) //used by sbt } object SparkILoop { - implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp - // Designed primarily for use by test code: take a String with a - // bunch of code, and prints out a transcript of what it would look - // like if you'd just typed it into the repl. - def runForTranscript(code: String, settings: Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { - override def write(str: String) = { - // completely skip continuation lines - if (str forall (ch => ch.isWhitespace || ch == '|')) () - else super.write(str) - } - } - val input = new BufferedReader(new StringReader(code.trim + "\n")) { - override def readLine(): String = { - val s = super.readLine() - // helping out by printing the line being interpreted. - if (s != null) - // scalastyle:off println - output.println(s) - // scalastyle:on println - s - } - } - val repl = new SparkILoop(input, output) - if (settings.classpath.isDefault) - settings.classpath.value = sys.props("java.class.path") - - repl process settings - } - } - } - - /** Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ + /** + * Creates an interpreter loop with default settings and feeds + * the given code to it as input. + */ def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) if (sets.classpath.isDefault) sets.classpath.value = sys.props("java.class.path") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala deleted file mode 100644 index 56c009a4e38e7..0000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ /dev/null @@ -1,1323 +0,0 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Martin Odersky - */ - -package scala -package tools.nsc -package interpreter - -import PartialFunction.cond -import scala.language.implicitConversions -import scala.beans.BeanProperty -import scala.collection.mutable -import scala.concurrent.{ Future, ExecutionContext } -import scala.reflect.runtime.{ universe => ru } -import scala.reflect.{ ClassTag, classTag } -import scala.reflect.internal.util.{ BatchSourceFile, SourceFile } -import scala.tools.util.PathResolver -import scala.tools.nsc.io.AbstractFile -import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings } -import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps } -import scala.tools.nsc.util.Exceptional.unwrap -import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable} - -/** An interpreter for Scala code. - * - * The main public entry points are compile(), interpret(), and bind(). - * The compile() method loads a complete Scala file. The interpret() method - * executes one line of Scala code at the request of the user. The bind() - * method binds an object to a variable that can then be used by later - * interpreted code. - * - * The overall approach is based on compiling the requested code and then - * using a Java classloader and Java reflection to run the code - * and access its results. - * - * In more detail, a single compiler instance is used - * to accumulate all successfully compiled or interpreted Scala code. To - * "interpret" a line of code, the compiler generates a fresh object that - * includes the line of code and which has public member(s) to export - * all variables defined by that code. To extract the result of an - * interpreted line to show the user, a second "result object" is created - * which imports the variables exported by the above object and then - * exports members called "$eval" and "$print". To accomodate user expressions - * that read from variables or methods defined in previous statements, "import" - * statements are used. - * - * This interpreter shares the strengths and weaknesses of using the - * full compiler-to-Java. The main strength is that interpreted code - * behaves exactly as does compiled code, including running at full speed. - * The main weakness is that redefining classes and methods is not handled - * properly, because rebinding at the Java level is technically difficult. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - */ -class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings, - protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports { - imain => - - setBindings(createBindings, ScriptContext.ENGINE_SCOPE) - object replOutput extends ReplOutput(settings.Yreploutdir) { } - - @deprecated("Use replOutput.dir instead", "2.11.0") - def virtualDirectory = replOutput.dir - // Used in a test case. - def showDirectory() = replOutput.show(out) - - private[nsc] var printResults = true // whether to print result lines - private[nsc] var totalSilence = false // whether to print anything - private var _initializeComplete = false // compiler is initialized - private var _isInitialized: Future[Boolean] = null // set up initialization future - private var bindExceptions = true // whether to bind the lastException variable - private var _executionWrapper = "" // code to be wrapped around all lines - - /** We're going to go to some trouble to initialize the compiler asynchronously. - * It's critical that nothing call into it until it's been initialized or we will - * run into unrecoverable issues, but the perceived repl startup time goes - * through the roof if we wait for it. So we initialize it with a future and - * use a lazy val to ensure that any attempt to use the compiler object waits - * on the future. - */ - private var _classLoader: util.AbstractFileClassLoader = null // active classloader - private val _compiler: ReplGlobal = newCompiler(settings, reporter) // our private compiler - - def compilerClasspath: Seq[java.net.URL] = ( - if (isInitializeComplete) global.classPath.asURLs - else new PathResolver(settings).result.asURLs // the compiler's classpath - ) - def settings = initialSettings - // Run the code body with the given boolean settings flipped to true. - def withoutWarnings[T](body: => T): T = beQuietDuring { - val saved = settings.nowarn.value - if (!saved) - settings.nowarn.value = true - - try body - finally if (!saved) settings.nowarn.value = false - } - - /** construct an interpreter that reports to Console */ - def this(settings: Settings, out: JPrintWriter) = this(null, settings, out) - def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this(factory: ScriptEngineFactory) = this(factory, new Settings()) - def this() = this(new Settings()) - - lazy val formatting: Formatting = new Formatting { - val prompt = Properties.shellPromptString - } - lazy val reporter: SparkReplReporter = new SparkReplReporter(this) - - import formatting._ - import reporter.{ printMessage, printUntruncatedMessage } - - // This exists mostly because using the reporter too early leads to deadlock. - private def echo(msg: String) { Console println msg } - private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) - private def _initialize() = { - try { - // if this crashes, REPL will hang its head in shame - val run = new _compiler.Run() - assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") - run compileSources _initSources - _initializeComplete = true - true - } - catch AbstractOrMissingHandler() - } - private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" - private val logScope = scala.sys.props contains "scala.repl.scope" - // scalastyle:off println - private def scopelog(msg: String) = if (logScope) Console.err.println(msg) - // scalastyle:on println - - // argument is a thunk to execute after init is done - def initialize(postInitSignal: => Unit) { - synchronized { - if (_isInitialized == null) { - _isInitialized = - Future(try _initialize() finally postInitSignal)(ExecutionContext.global) - } - } - } - def initializeSynchronous(): Unit = { - if (!isInitializeComplete) { - _initialize() - assert(global != null, global) - } - } - def isInitializeComplete = _initializeComplete - - lazy val global: Global = { - if (!isInitializeComplete) _initialize() - _compiler - } - - import global._ - import definitions.{ ObjectClass, termMember, dropNullaryMethod} - - lazy val runtimeMirror = ru.runtimeMirror(classLoader) - - private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol } - - def getClassIfDefined(path: String) = ( - noFatal(runtimeMirror staticClass path) - orElse noFatal(rootMirror staticClass path) - ) - def getModuleIfDefined(path: String) = ( - noFatal(runtimeMirror staticModule path) - orElse noFatal(rootMirror staticModule path) - ) - - implicit class ReplTypeOps(tp: Type) { - def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) - } - - // TODO: If we try to make naming a lazy val, we run into big time - // scalac unhappiness with what look like cycles. It has not been easy to - // reduce, but name resolution clearly takes different paths. - object naming extends { - val global: imain.global.type = imain.global - } with Naming { - // make sure we don't overwrite their unwisely named res3 etc. - def freshUserTermName(): TermName = { - val name = newTermName(freshUserVarName()) - if (replScope containsName name) freshUserTermName() - else name - } - def isInternalTermName(name: Name) = isInternalVarName("" + name) - } - import naming._ - - object deconstruct extends { - val global: imain.global.type = imain.global - } with StructuredTypeStrings - - lazy val memberHandlers = new { - val intp: imain.type = imain - } with SparkMemberHandlers - import memberHandlers._ - - /** Temporarily be quiet */ - def beQuietDuring[T](body: => T): T = { - val saved = printResults - printResults = false - try body - finally printResults = saved - } - def beSilentDuring[T](operation: => T): T = { - val saved = totalSilence - totalSilence = true - try operation - finally totalSilence = saved - } - - def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - /** takes AnyRef because it may be binding a Throwable or an Exceptional */ - private def withLastExceptionLock[T](body: => T, alt: => T): T = { - assert(bindExceptions, "withLastExceptionLock called incorrectly.") - bindExceptions = false - - try beQuietDuring(body) - catch logAndDiscard("withLastExceptionLock", alt) - finally bindExceptions = true - } - - def executionWrapper = _executionWrapper - def setExecutionWrapper(code: String) = _executionWrapper = code - def clearExecutionWrapper() = _executionWrapper = "" - - /** interpreter settings */ - lazy val isettings = new SparkISettings(this) - - /** Instantiate a compiler. Overridable. */ - protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = { - settings.outputDirs setSingleOutput replOutput.dir - settings.exposeEmptyPackage.value = true - new Global(settings, reporter) with ReplGlobal { override def toString: String = "" } - } - - /** Parent classloader. Overridable. */ - protected def parentClassLoader: ClassLoader = - settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() ) - - /* A single class loader is used for all commands interpreted by this Interpreter. - It would also be possible to create a new class loader for each command - to interpret. The advantages of the current approach are: - - - Expressions are only evaluated one time. This is especially - significant for I/O, e.g. "val x = Console.readLine" - - The main disadvantage is: - - - Objects, classes, and methods cannot be rebound. Instead, definitions - shadow the old ones, and old code objects refer to the old - definitions. - */ - def resetClassLoader() = { - repldbg("Setting new classloader: was " + _classLoader) - _classLoader = null - ensureClassLoader() - } - final def ensureClassLoader() { - if (_classLoader == null) - _classLoader = makeClassLoader() - } - def classLoader: util.AbstractFileClassLoader = { - ensureClassLoader() - _classLoader - } - - def backticked(s: String): String = ( - (s split '.').toList map { - case "_" => "_" - case s if nme.keywords(newTermName(s)) => s"`$s`" - case s => s - } mkString "." - ) - def readRootPath(readPath: String) = getModuleIfDefined(readPath) - - abstract class PhaseDependentOps { - def shift[T](op: => T): T - - def path(name: => Name): String = shift(path(symbolOfName(name))) - def path(sym: Symbol): String = backticked(shift(sym.fullName)) - def sig(sym: Symbol): String = shift(sym.defString) - } - object typerOp extends PhaseDependentOps { - def shift[T](op: => T): T = exitingTyper(op) - } - object flatOp extends PhaseDependentOps { - def shift[T](op: => T): T = exitingFlatten(op) - } - - def originalPath(name: String): String = originalPath(name: TermName) - def originalPath(name: Name): String = typerOp path name - def originalPath(sym: Symbol): String = typerOp path sym - def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName - def translatePath(path: String) = { - val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path) - sym.toOption map flatPath - } - def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath - - private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) { - /** Overridden here to try translating a simple name to the generated - * class name if the original attempt fails. This method is used by - * getResourceAsStream as well as findClass. - */ - override protected def findAbstractFile(name: String): AbstractFile = - super.findAbstractFile(name) match { - case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull - case file => file - } - } - private def makeClassLoader(): util.AbstractFileClassLoader = - new TranslatingClassLoader(parentClassLoader match { - case null => ScalaClassLoader fromURLs compilerClasspath - case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p) - }) - - // Set the current Java "context" class loader to this interpreter's class loader - def setContextClassLoader() = classLoader.setAsContext() - - def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted) - def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted - - /** Most recent tree handled which wasn't wholly synthetic. */ - private def mostRecentlyHandledTree: Option[Tree] = { - prevRequests.reverse foreach { req => - req.handlers.reverse foreach { - case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) - case _ => () - } - } - None - } - - private def updateReplScope(sym: Symbol, isDefined: Boolean) { - def log(what: String) { - val mark = if (sym.isType) "t " else "v " - val name = exitingTyper(sym.nameString) - val info = cleanTypeAfterTyper(sym) - val defn = sym defStringSeenAs info - - scopelog(f"[$mark$what%6s] $name%-25s $defn%s") - } - if (ObjectClass isSubClass sym.owner) return - // unlink previous - replScope lookupAll sym.name foreach { sym => - log("unlink") - replScope unlink sym - } - val what = if (isDefined) "define" else "import" - log(what) - replScope enter sym - } - - def recordRequest(req: Request) { - if (req == null) - return - - prevRequests += req - - // warning about serially defining companions. It'd be easy - // enough to just redefine them together but that may not always - // be what people want so I'm waiting until I can do it better. - exitingTyper { - req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym => - val oldSym = replScope lookup newSym.name.companionName - if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) { - replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.") - replwarn("Companions must be defined together; you may wish to use :paste mode for this.") - } - } - } - exitingTyper { - req.imports foreach (sym => updateReplScope(sym, isDefined = false)) - req.defines foreach (sym => updateReplScope(sym, isDefined = true)) - } - } - - private[nsc] def replwarn(msg: => String) { - if (!settings.nowarnings) - printMessage(msg) - } - - def compileSourcesKeepingRun(sources: SourceFile*) = { - val run = new Run() - assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") - reporter.reset() - run compileSources sources.toList - (!reporter.hasErrors, run) - } - - /** Compile an nsc SourceFile. Returns true if there are - * no compilation errors, or false otherwise. - */ - def compileSources(sources: SourceFile*): Boolean = - compileSourcesKeepingRun(sources: _*)._1 - - /** Compile a string. Returns true if there are no - * compilation errors, or false otherwise. - */ - def compileString(code: String): Boolean = - compileSources(new BatchSourceFile("