[SPARK-14142][SQL] Replace internal use of unionAll with union
## What changes were proposed in this pull request?
`unionAll` was deprecated in SPARK-14088, so this patch replaces all internal uses of `unionAll` with `union`.
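
For callers the rewrite is mechanical, since `union` carries the exact `UNION ALL` semantics that `unionAll` had. A minimal sketch of the migration (the `sqlContext` and toy data below are illustrative assumptions, not part of this patch):

```scala
import sqlContext.implicits._  // assumes an existing SQLContext in scope

val df = Seq(1, 2, 3).toDF("id")

val before = df.unionAll(df)  // deprecated since SPARK-14088
val after  = df.union(df)     // same UNION ALL semantics: duplicates are kept

assert(after.count() == 6)    // 3 + 3 rows, nothing deduplicated
```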

## How was this patch tested?
Should be covered by all existing tests.

Author: Reynold Xin <rxin@databricks.com>

Closes apache#11946 from rxin/SPARK-14142.
rxin committed Mar 25, 2016
1 parent 13cbb2d commit 3619fec
Showing 21 changed files with 45 additions and 45 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -360,7 +360,7 @@ def repartition(self, numPartitions, *cols):
>>> df.repartition(10).rdd.getNumPartitions()
10
- >>> data = df.unionAll(df).repartition("age")
+ >>> data = df.union(df).repartition("age")
>>> data.show()
+---+-----+
|age| name|
@@ -919,7 +919,7 @@ def union(self, other):
This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union
(that does deduplication of elements), use this function followed by a distinct.
"""
- return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)

@since(1.3)
def unionAll(self, other):
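The docstring above carries the key behavioral note: `union` is SQL's `UNION ALL`, so set-style deduplication needs an explicit `distinct()` afterwards. A small sketch of the difference (toy data assumed):

```scala
import sqlContext.implicits._

val a = Seq((1, "x"), (2, "y")).toDF("id", "v")
val b = Seq((2, "y"), (3, "z")).toDF("id", "v")

a.union(b).count()             // 4: UNION ALL keeps the duplicate row (2, "y")
a.union(b).distinct().count()  // 3: distinct() after union gives SQL's UNION
```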
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests.py
@@ -599,7 +599,7 @@ def test_parquet_with_udt(self):
point = df1.head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))

- def test_unionAll_with_udt(self):
+ def test_union_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row1 = (1.0, ExamplePoint(1.0, 2.0))
row2 = (2.0, ExamplePoint(3.0, 4.0))
@@ -608,7 +608,7 @@ def test_unionAll_with_udt(self):
df1 = self.sqlCtx.createDataFrame([row1], schema)
df2 = self.sqlCtx.createDataFrame([row2], schema)

- result = df1.unionAll(df2).orderBy("label").collect()
+ result = df1.union(df2).orderBy("label").collect()
self.assertEqual(
result,
[
@@ -280,7 +280,7 @@ package object dsl {

def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan)

- def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
+ def union(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)

def generate(
generator: Generator,
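The DSL method is thin sugar: it wraps the two plans in a catalyst `Union` node that test suites then hand to the analyzer or optimizer. A hedged sketch of typical usage under the usual catalyst test imports (the `LocalRelation` construction via the expression DSL is an assumption for illustration):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Union}

val testRelation = LocalRelation('a.int, 'b.string)

// union is sugar for building the logical Union node directly.
val plan = testRelation.union(testRelation)
assert(plan.isInstanceOf[Union])
```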
@@ -250,7 +250,7 @@ class AnalysisErrorSuite extends AnalysisTest {

errorTest(
"union with unequal number of columns",
- testRelation.unionAll(testRelation2),
+ testRelation.union(testRelation2),
"union" :: "number of columns" :: testRelation2.output.length.toString ::
testRelation.output.length.toString :: Nil)

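End to end, that analysis error surfaces as an `AnalysisException` whenever the two sides of a union disagree on column count, and it fires at the call site because DataFrames are analyzed eagerly. An illustrative sketch (ScalaTest's `intercept` and the toy frames are assumptions):

```scala
import org.apache.spark.sql.AnalysisException
import sqlContext.implicits._

val two = Seq((1, "x")).toDF("a", "b")
val one = Seq(1).toDF("a")

// Fails analysis: unions require the same number of columns on both sides.
intercept[AnalysisException] { two.union(one) }
```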
@@ -32,7 +32,7 @@ class AnalysisSuite extends AnalysisTest {
val plan = (1 to 100)
.map(_ => testRelation)
.fold[LogicalPlan](testRelation) { (a, b) =>
- a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
+ a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
}

assertAnalysisSuccess(plan)
@@ -60,8 +60,8 @@ class PruneFiltersSuite extends PlanTest {

val query =
tr1.where('a.attr > 10)
- .unionAll(tr2.where('d.attr > 10)
- .unionAll(tr3.where('g.attr > 10)))
+ .union(tr2.where('d.attr > 10)
+ .union(tr3.where('g.attr > 10)))
val queryWithUselessFilter = query.where('a.attr > 10)

val optimized = Optimize.execute(queryWithUselessFilter.analyze)
@@ -109,14 +109,14 @@ class ConstraintPropagationSuite extends SparkFunSuite {

assert(tr1
.where('a.attr > 10)
- .unionAll(tr2.where('e.attr > 10)
- .unionAll(tr3.where('i.attr > 10)))
+ .union(tr2.where('e.attr > 10)
+ .union(tr3.where('i.attr > 10)))
.analyze.constraints.isEmpty)

verifyConstraints(tr1
.where('a.attr > 10)
- .unionAll(tr2.where('d.attr > 10)
- .unionAll(tr3.where('g.attr > 10)))
+ .union(tr2.where('d.attr > 10)
+ .union(tr3.where('g.attr > 10)))
.analyze.constraints,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
@@ -97,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
newBlocks
.map(_.toDF())
- .reduceOption(_ unionAll _)
+ .reduceOption(_ union _)
.getOrElse {
sys.error("No data selected!")
}
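`reduceOption(_ union _)` collapses a variable-length batch of DataFrames into a single plan while keeping the empty case explicit. The same pattern in isolation (toy frames assumed):

```scala
import org.apache.spark.sql.DataFrame
import sqlContext.implicits._

val parts: Seq[DataFrame] = Seq(Seq(1).toDF("v"), Seq(2).toDF("v"), Seq(3).toDF("v"))

val combined = parts
  .reduceOption(_ union _)                  // None when parts is empty
  .getOrElse(sys.error("No data selected!"))

assert(combined.count() == 3)
```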
@@ -363,7 +363,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
}

test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
- val table3x = testData.unionAll(testData).unionAll(testData)
+ val table3x = testData.union(testData).union(testData)
table3x.registerTempTable("testData3x")

sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable")
@@ -57,7 +57,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")

- assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+ assert(splits.reduce((a, b) => a.union(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")

val s = splits.map(_.count())
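The assertion above leans on `randomSplit` producing disjoint, exhaustive splits, so re-unioning them must reproduce the input exactly. A condensed sketch of that invariant (seed value arbitrary):

```scala
val data = sqlContext.range(0, 100).toDF("id")
val splits = data.randomSplit(Array(1.0, 2.0, 3.0), 17L)

// Disjoint and exhaustive: the union of the splits restores the original rows.
val reassembled = splits.reduce(_ union _).sort("id")
assert(reassembled.collect().toList == data.collect().toList)
```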
@@ -94,8 +94,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}

test("union all") {
- val unionDF = testData.unionAll(testData).unionAll(testData)
- .unionAll(testData).unionAll(testData)
+ val unionDF = testData.union(testData).union(testData)
+ .union(testData).union(testData)

// Before optimizer, Union should be combined.
assert(unionDF.queryExecution.analyzed.collect {
@@ -107,7 +107,7 @@
)
}

test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
test("union should union DataFrames with UDTs (SPARK-13410)") {
val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0))))
val schema1 = StructType(Array(StructField("label", IntegerType, false),
StructField("point", new ExamplePointUDT(), false)))
@@ -118,7 +118,7 @@
val df2 = sqlContext.createDataFrame(rowRDD2, schema2)

checkAnswer(
- df1.unionAll(df2).orderBy("label"),
+ df1.union(df2).orderBy("label"),
Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0)))
)
}
@@ -636,7 +636,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val jsonDF = sqlContext.read.json(jsonDir)
assert(parquetDF.inputFiles.nonEmpty)

- val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted
+ val unioned = jsonDF.union(parquetDF).inputFiles.sorted
val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).distinct.sorted
assert(unioned === allFiles)
}
@@ -1104,7 +1104,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}

- val union = df1.unionAll(df2)
+ val union = df1.union(df2)
checkAnswer(
union.filter('i < rand(7) * 10),
expected(union)
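As the `queryExecution.analyzed.collect` assertion above verifies, a chain of `union` calls does not stay a nested tree: analysis flattens it into a single n-ary `Union`. A minimal sketch of the same check (assuming a `testData` frame as in the suite):

```scala
import org.apache.spark.sql.catalyst.plans.logical.Union

val unionDF = testData.union(testData).union(testData)

// The three-way chain analyzes to one Union node, not two nested ones.
val unions = unionDF.queryExecution.analyzed.collect { case u: Union => u }
assert(unions.size == 1)
```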
@@ -184,7 +184,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}

test("big inner join, 4 matches per row") {
- val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
+ val bigData = testData.union(testData).union(testData).union(testData)
val bigDataX = bigData.as("x")
val bigDataY = bigData.as("y")

@@ -251,8 +251,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("aggregation with codegen") {
// Prepare a table that we can group some rows.
sqlContext.table("testData")
- .unionAll(sqlContext.table("testData"))
- .unionAll(sqlContext.table("testData"))
+ .union(sqlContext.table("testData"))
+ .union(sqlContext.table("testData"))
.registerTempTable("testData3x")

try {
@@ -342,7 +342,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
sqlContext
.range(0, 1000)
.selectExpr("id % 500 as key", "id as value")
- .unionAll(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+ .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
checkAnswer(
join,
expectedAnswer.collect())
@@ -44,7 +44,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
path.delete()

val base = sqlContext.range(100)
- val df = base.unionAll(base).select($"id", lit(1).as("data"))
+ val df = base.union(base).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)

checkAnswer(
@@ -122,7 +122,7 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA

// verify the append mode
df.write.mode(SaveMode.Append).json(path.toString)
- val df2 = df.unionAll(df)
+ val df2 = df.union(df)
df2.registerTempTable("jsonTable2")

checkLoad(df2, "jsonTable2")
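Append mode writes alongside whatever is already at the path, so writing `df` twice reads back as `df.union(df)`, which is what the `checkLoad` call above asserts. A hedged sketch of that round trip (`path` pointing at a fresh directory and the `QueryTest.checkAnswer` helper are assumed from the surrounding suite):

```scala
import org.apache.spark.sql.SaveMode
import sqlContext.implicits._

val df = Seq(1, 2, 3).toDF("a")

df.write.json(path.toString)                        // initial write to a fresh path
df.write.mode(SaveMode.Append).json(path.toString)  // append the same rows

// Reading back sees both writes, i.e. df UNION ALL df.
checkAnswer(sqlContext.read.json(path.toString), df.union(df))
```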
@@ -82,7 +82,7 @@ public void saveTableAndQueryIt() {

@Test
public void testUDAF() {
- Dataset<Row> df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
+ Dataset<Row> df = hc.range(0, 100).union(hc.range(0, 100)).select(col("id").as("value"));
UserDefinedAggregateFunction udaf = new MyDoubleSum();
UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
// Create Columns for the UDAF. For now, callUDF does not take an argument to specific if
@@ -186,7 +186,7 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton {
assertCached(table("refreshTable"))
checkAnswer(
table("refreshTable"),
table("src").unionAll(table("src")).collect())
table("src").union(table("src")).collect())

// Drop the table and create it again.
sql("DROP TABLE refreshTable")
@@ -198,7 +198,7 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton {
sql("REFRESH TABLE refreshTable")
checkAnswer(
table("refreshTable"),
table("src").unionAll(table("src")).collect())
table("src").union(table("src")).collect())
// It is not cached.
assert(!isCached("refreshTable"), "refreshTable should not be cached.")

@@ -113,11 +113,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
df.write.mode(SaveMode.Append).saveAsTable("t")
assert(sqlContext.tableNames().contains("t"))
checkAnswer(sqlContext.table("t"), df.unionAll(df))
checkAnswer(sqlContext.table("t"), df.union(df))
}

assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))

checkTablePath(db, "t")
}
@@ -128,7 +128,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))

checkTablePath(db, "t")
}
@@ -141,7 +141,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
assert(sqlContext.tableNames().contains("t"))

df.write.insertInto(s"$db.t")
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
}
}
}
@@ -156,7 +156,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
assert(sqlContext.tableNames(db).contains("t"))

df.write.insertInto(s"$db.t")
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
}
}

@@ -220,7 +220,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
hiveContext.refreshTable("t")
checkAnswer(
sqlContext.table("t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2))))
}
}
}
@@ -252,7 +252,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
hiveContext.refreshTable(s"$db.t")
checkAnswer(
sqlContext.table(s"$db.t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2))))
}
}
}
@@ -137,7 +137,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
fs.delete(commonSummaryPath, true)

df.write.mode(SaveMode.Append).parquet(path)
- checkAnswer(sqlContext.read.parquet(path), df.unionAll(df))
+ checkAnswer(sqlContext.read.parquet(path), df.union(df))

assert(fs.exists(summaryPath))
assert(fs.exists(commonSummaryPath))
@@ -60,7 +60,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
p2 <- Seq("foo", "bar")
} yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2")

- lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2)
+ lazy val partitionedTestDF = partitionedTestDF1.union(partitionedTestDF2)

def checkQueries(df: DataFrame): Unit = {
// Selects everything
@@ -191,7 +191,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath).orderBy("a"),
- testDF.unionAll(testDF).orderBy("a").collect())
+ testDF.union(testDF).orderBy("a").collect())
}
}

@@ -268,7 +268,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
- partitionedTestDF.unionAll(partitionedTestDF).collect())
+ partitionedTestDF.union(partitionedTestDF).collect())
}
}

@@ -332,7 +332,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t")

withTable("t") {
checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect())
checkAnswer(sqlContext.table("t"), testDF.union(testDF).orderBy("a").collect())
}
}

@@ -415,7 +415,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")

withTable("t") {
checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect())
checkAnswer(sqlContext.table("t"), partitionedTestDF.union(partitionedTestDF).collect())
}
}

@@ -625,7 +625,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.format(dataSourceName)
.option("dataSchema", df.schema.json)
.load(dir.getCanonicalPath),
- df.unionAll(df))
+ df.union(df))

// This will fail because AlwaysFailOutputCommitter is used when we do append.
intercept[Exception] {
