diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ec0b3c78ed72c..703ea4d1498ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1178,32 +1178,6 @@ class Dataset[T] private[sql](
       withGroupingKey.newColumns)
   }
 
-  /**
-   * :: Experimental ::
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given [[Column]]
-   * expressions.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  @Experimental
-  @scala.annotation.varargs
-  def groupByKey(cols: Column*): KeyValueGroupedDataset[Row, T] = {
-    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
-    val withKey = Project(withKeyColumns, logicalPlan)
-    val executed = sqlContext.executePlan(withKey)
-
-    val dataAttributes = executed.analyzed.output.dropRight(cols.size)
-    val keyAttributes = executed.analyzed.output.takeRight(cols.size)
-
-    new KeyValueGroupedDataset(
-      RowEncoder(keyAttributes.toStructType),
-      encoderFor[T],
-      executed,
-      dataAttributes,
-      keyAttributes)
-  }
-
   /**
    * :: Experimental ::
    * (Java-specific)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 18f17a85a9dbd..86db8df4c00fd 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -245,29 +245,6 @@ public Iterator<String> call(Integer key, Iterator<String> left, Iterator<Integer> right)
     Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList()));
   }
 
-  @Test
-  public void testGroupByColumn() {
-    List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
-    KeyValueGroupedDataset<Integer, String> grouped =
-      ds.groupByKey(length(col("value"))).keyAs(Encoders.INT());
-
-    Dataset<String> mapped = grouped.mapGroups(
-      new MapGroupsFunction<Integer, String, String>() {
-        @Override
-        public String call(Integer key, Iterator<String> data) throws Exception {
-          StringBuilder sb = new StringBuilder(key.toString());
-          while (data.hasNext()) {
-            sb.append(data.next());
-          }
-          return sb.toString();
-        }
-      },
-      Encoders.STRING());
-
-    Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
-  }
-
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 2e5179a8d2c95..942cc09b6d58e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -63,7 +63,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
 
   test("persist and then groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1").keyAs[String]
+    val grouped = ds.groupByKey(_._1)
     val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
     agged.persist()
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0bcc512d7137d..553bc436a6456 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -322,55 +322,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     )
   }
 
-  test("groupBy columns, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1")
-    val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      ("a", 30), ("b", 3), ("c", 1))
-  }
-
-  test("groupBy columns, count") {
-    val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
-    val count = ds.groupByKey($"_1").count()
-
-    checkDataset(
-      count,
-      (Row("a"), 2L), (Row("b"), 1L))
-  }
-
-  test("groupBy columns asKey, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1").keyAs[String]
-    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      ("a", 30), ("b", 3), ("c", 1))
-  }
-
-  test("groupBy columns asKey tuple, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1", lit(1)).keyAs[(String, Int)]
-    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
-  }
-
-  test("groupBy columns asKey class, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
-    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
-  }
-
   test("typed aggregation: expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
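
Migration note: callers of the removed Column-based overload can switch to the typed `groupByKey` that remains, as the DatasetCacheSuite change above shows. A minimal sketch, assuming a Spark 2.0-era context with the SQL implicits in scope (the setup lines are illustrative; the grouping lines are taken from the tests in this diff):

    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // Before: key built from a Column, typed key recovered via keyAs:
    //   val grouped = ds.groupByKey($"_1").keyAs[String]
    // After: key extracted by a typed function, no Row/keyAs round-trip:
    val grouped = ds.groupByKey(_._1)
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    // agged.collect() => Array(("a", 30), ("b", 3), ("c", 1))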