forked from eldersantos/community
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
180 additions
and
0 deletions.
There are no files selected for viewing
82 changes: 82 additions & 0 deletions
82
cypher/src/main/scala/org/neo4j/cypher/internal/pipes/OrderedAggregationPipe.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package org.neo4j.cypher.internal.pipes | ||
|
||
/** | ||
* Copyright (c) 2002-2011 "Neo Technology," | ||
* Network Engine for Objects in Lund AB [http://neotechnology.com] | ||
* | ||
* This file is part of Neo4j. | ||
* | ||
* Neo4j is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
|
||
import collection.Seq | ||
import org.neo4j.cypher.commands.{AggregationItem, ReturnItem} | ||
import java.lang.String | ||
import org.neo4j.cypher.symbols.{Identifier, SymbolTable} | ||
import org.neo4j.cypher.pipes.aggregation.AggregationFunction | ||
import org.neo4j.helpers.ThisShouldNotHappenError | ||
|
||
// This class can be used to aggregate if the values sub graphs come in the order that they are keyed on | ||
class OrderedAggregationPipe(source: Pipe, val returnItems: Seq[ReturnItem], aggregations: Seq[AggregationItem]) extends PipeWithSource(source) { | ||
|
||
if (returnItems.isEmpty) | ||
throw new ThisShouldNotHappenError("Andres Taylor", "The ordered aggregation pipe should never be used without aggregation keys") | ||
|
||
val symbols: SymbolTable = createSymbols() | ||
|
||
def dependencies: Seq[Identifier] = returnItems.flatMap(_.dependencies) ++ aggregations.flatMap(_.dependencies) | ||
|
||
def createSymbols() = { | ||
val keySymbols = source.symbols.filter(returnItems.map(_.columnName): _*) | ||
val aggregatedColumns = aggregations.map(_.concreteReturnItem.identifier) | ||
|
||
keySymbols.add(aggregatedColumns: _*) | ||
} | ||
|
||
def createResults[U](params: Map[String, Any]): Traversable[Map[String, Any]] = new OrderedAggregator(source.createResults(params), returnItems, aggregations) | ||
|
||
override def executionPlan(): String = source.executionPlan() + "\r\n" + "EagerAggregation( keys: [" + returnItems.map(_.columnName).mkString(", ") + "], aggregates: [" + aggregations.mkString(", ") + "])" | ||
} | ||
|
||
private class OrderedAggregator(source: Traversable[Map[String, Any]], | ||
returnItems: Seq[ReturnItem], | ||
aggregations: Seq[AggregationItem]) extends Traversable[Map[String, Any]] { | ||
var currentKey: Option[Seq[Any]] = None | ||
var aggregationSpool: Seq[AggregationFunction] = null | ||
val keyColumns = returnItems.map(_.columnName) | ||
val aggregateColumns = aggregations.map(_.columnName) | ||
|
||
def getIntermediateResults[U]() = (keyColumns.zip(currentKey.get) ++ aggregateColumns.zip(aggregationSpool.map(_.result))).toMap | ||
|
||
def foreach[U](f: (Map[String, Any]) => U) { | ||
source.foreach(m => { | ||
val key = Some(returnItems.map(_.apply(m))) | ||
if (currentKey.isEmpty) { | ||
aggregationSpool = aggregations.map(_.createAggregationFunction) | ||
currentKey = key | ||
} else if (key != currentKey) { | ||
f(getIntermediateResults()) | ||
|
||
aggregationSpool = aggregations.map(_.createAggregationFunction) | ||
currentKey = key | ||
} | ||
|
||
aggregationSpool.foreach(func => func(m)) | ||
}) | ||
|
||
if (currentKey.nonEmpty) { | ||
f(getIntermediateResults()) | ||
} | ||
} | ||
} |
98 changes: 98 additions & 0 deletions
98
cypher/src/test/scala/org/neo4j/cypher/internal/pipes/OrderedAggregationPipeTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package org.neo4j.cypher.internal.pipes | ||
|
||
/** | ||
* Copyright (c) 2002-2011 "Neo Technology," | ||
* Network Engine for Objects in Lund AB [http://neotechnology.com] | ||
* | ||
* This file is part of Neo4j. | ||
* | ||
* Neo4j is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
|
||
import org.junit.Test | ||
import org.junit.Assert._ | ||
import org.junit.matchers.JUnitMatchers._ | ||
import scala.collection.JavaConverters._ | ||
import org.neo4j.cypher.commands._ | ||
import org.scalatest.junit.JUnitSuite | ||
import org.neo4j.cypher.symbols.{IntegerType, SymbolTable, Identifier, NodeType} | ||
import org.scalatest.Assertions | ||
import org.neo4j.cypher.SyntaxException | ||
import org.neo4j.helpers.ThisShouldNotHappenError | ||
|
||
class OrderedAggregationPipeTest extends JUnitSuite with Assertions { | ||
@Test def shouldReturnColumnsFromReturnItems() { | ||
val source = new FakePipe(List(), createSymbolTableFor("name")) | ||
|
||
val returnItems = List(ExpressionReturnItem(Entity("name"))) | ||
val grouping = List(CountStar()) | ||
val aggregationPipe = new OrderedAggregationPipe(source, returnItems, grouping) | ||
|
||
assertEquals( | ||
Seq(Identifier("name", NodeType()), Identifier("count(*)", IntegerType())), | ||
aggregationPipe.symbols.identifiers) | ||
} | ||
|
||
@Test def shouldThrowSemanticException() { | ||
val source = new FakePipe(List(), createSymbolTableFor("extractReturnItems")) | ||
|
||
val returnItems = List(ExpressionReturnItem(Entity("name"))) | ||
val grouping = List(ValueAggregationItem(Count(Entity("none-existing-identifier")))) | ||
intercept[SyntaxException](new OrderedAggregationPipe(source, returnItems, grouping)) | ||
} | ||
|
||
@Test def shouldAggregateCountStar() { | ||
val source = new FakePipe(List( | ||
Map("name" -> "Andres", "age" -> 36), | ||
Map("name" -> "Michael", "age" -> 36), | ||
Map("name" -> "Michael", "age" -> 31), | ||
Map("name" -> "Peter", "age" -> 38) | ||
), createSymbolTableFor("name")) | ||
|
||
val returnItems = List(ExpressionReturnItem(Entity("name"))) | ||
val grouping = List(CountStar()) | ||
val aggregationPipe = new OrderedAggregationPipe(source, returnItems, grouping) | ||
|
||
assertThat(aggregationPipe.createResults(Map()).toIterable.asJava, hasItems( | ||
Map("name" -> "Andres", "count(*)" -> 1), | ||
Map("name" -> "Michael", "count(*)" -> 2), | ||
Map("name" -> "Peter", "count(*)" -> 1))) | ||
} | ||
|
||
@Test def shouldCountNonNullValues() { | ||
val source = new FakePipe(List( | ||
Map("name" -> "Andres", "age" -> 36), | ||
Map("name" -> "Michael", "age" -> null), | ||
Map("name" -> "Peter", "age" -> 38)), createSymbolTableFor("name", "age")) | ||
|
||
val returnItems = List(ExpressionReturnItem(Entity("name"))) | ||
val grouping = List(ValueAggregationItem(Count((Entity("age"))))) | ||
val aggregationPipe = new OrderedAggregationPipe(source, returnItems, grouping) | ||
|
||
assertThat(aggregationPipe.createResults(Map()).toIterable.asJava, hasItems( | ||
Map("name" -> "Andres", "count(age)" -> 1), | ||
Map("name" -> "Michael", "count(age)" -> 0), | ||
Map("name" -> "Peter", "count(age)" -> 1))) } | ||
|
||
@Test def shouldThrowOnEmptyKeyList() { | ||
val source = new FakePipe(List(), createSymbolTableFor("name")) | ||
|
||
val returnItems = List() | ||
val grouping = List(ValueAggregationItem(Count((Entity("name"))))) | ||
intercept[ThisShouldNotHappenError](new OrderedAggregationPipe(source, returnItems, grouping)) | ||
} | ||
|
||
private def createSymbolTableFor(names: String*) = new SymbolTable( names.map( Identifier(_, NodeType()) ):_* ) | ||
|
||
} |