Skip to content

Commit

Permalink
Added the OrderedAggregationPipe
Browse files Browse the repository at this point in the history
  • Loading branch information
systay committed Dec 14, 2011
1 parent 84b014a commit bd1851f
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 0 deletions.
@@ -0,0 +1,82 @@
package org.neo4j.cypher.internal.pipes

/**
* Copyright (c) 2002-2011 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

import collection.Seq
import org.neo4j.cypher.commands.{AggregationItem, ReturnItem}
import java.lang.String
import org.neo4j.cypher.symbols.{Identifier, SymbolTable}
import org.neo4j.cypher.pipes.aggregation.AggregationFunction
import org.neo4j.helpers.ThisShouldNotHappenError

/**
 * Aggregating pipe for input whose rows arrive already ordered by the grouping keys.
 *
 * The actual grouping is delegated to an OrderedAggregator, which closes a group as soon as
 * the key values change between two consecutive source rows. Rows that share a key but do not
 * arrive next to each other will therefore end up in separate groups.
 *
 * @param source       the pipe delivering the rows to aggregate
 * @param returnItems  the grouping keys; must not be empty
 * @param aggregations the aggregations computed for every group
 * @throws ThisShouldNotHappenError if no grouping keys are given
 */
class OrderedAggregationPipe(source: Pipe, val returnItems: Seq[ReturnItem], aggregations: Seq[AggregationItem]) extends PipeWithSource(source) {

  // Guard first: the symbol table below is built from the key columns.
  if (returnItems.isEmpty)
    throw new ThisShouldNotHappenError("Andres Taylor", "The ordered aggregation pipe should never be used without aggregation keys")

  val symbols: SymbolTable = createSymbols()

  /** Everything the grouping keys and the aggregations depend on. */
  def dependencies: Seq[Identifier] = returnItems.flatMap(_.dependencies) ++ aggregations.flatMap(_.dependencies)

  /**
   * Builds the symbol table exposed by this pipe: the source symbols restricted to the key
   * columns, extended with one identifier per aggregation.
   */
  def createSymbols() = {
    val keyColumnNames = returnItems.map(_.columnName)
    val keySymbols = source.symbols.filter(keyColumnNames: _*)
    val aggregatedColumns = aggregations.map(_.concreteReturnItem.identifier)

    keySymbols.add(aggregatedColumns: _*)
  }

  /** Wraps the source results in an OrderedAggregator; aggregation happens when the result is traversed. */
  def createResults[U](params: Map[String, Any]): Traversable[Map[String, Any]] =
    new OrderedAggregator(source.createResults(params), returnItems, aggregations)

  /**
   * Appends this pipe's description to the plan of the source.
   *
   * The operator is reported as "OrderedAggregation" so the plan shows which aggregation
   * strategy was chosen.
   */
  override def executionPlan(): String = {
    val keys = returnItems.map(_.columnName).mkString(", ")
    val aggregates = aggregations.mkString(", ")

    source.executionPlan() + "\r\n" + "OrderedAggregation( keys: [" + keys + "], aggregates: [" + aggregates + "])"
  }
}

/**
 * Traversable that folds runs of consecutive source rows with equal key values into one
 * aggregated row each.
 *
 * For every source row the key is computed by applying all return items to the row. When the
 * key differs from the key of the previous row, the finished group is handed to the callback
 * and a fresh set of aggregation functions is created. The last group is emitted once the
 * source is exhausted. An empty source produces no rows.
 *
 * @param source       the rows to aggregate, expected to arrive grouped by key
 * @param returnItems  the grouping keys, evaluated against every row
 * @param aggregations the aggregations to compute per group
 */
private class OrderedAggregator(source: Traversable[Map[String, Any]],
                                returnItems: Seq[ReturnItem],
                                aggregations: Seq[AggregationItem]) extends Traversable[Map[String, Any]] {
  // Key values of the group being accumulated; None while no row has been seen in this traversal.
  var currentKey: Option[Seq[Any]] = None
  // Running aggregation state of the current group, one function per aggregation item.
  var aggregationSpool: Seq[AggregationFunction] = Seq.empty[AggregationFunction]
  val keyColumns = returnItems.map(_.columnName)
  val aggregateColumns = aggregations.map(_.columnName)

  /**
   * Builds the result row for the current group: key columns paired with the current key
   * values, plus aggregate columns paired with the current aggregation results.
   * Must only be called while a group is open (currentKey is defined).
   */
  def getIntermediateResults[U]() = (keyColumns.zip(currentKey.get) ++ aggregateColumns.zip(aggregationSpool.map(_.result))).toMap

  /**
   * Traverses the source once and calls f with one aggregated row per key run.
   *
   * @param f callback receiving every finished group
   */
  def foreach[U](f: (Map[String, Any]) => U) {
    // The grouping state lives in fields, so it has to be cleared before every traversal.
    // Otherwise a second foreach on the same instance would start from the last group of the
    // previous run and emit a stale row or keep counting on old aggregation functions.
    currentKey = None
    aggregationSpool = Seq.empty[AggregationFunction]

    source.foreach(m => {
      val key = Some(returnItems.map(_.apply(m)))

      if (key != currentKey) {
        // Key changed: flush the finished group (if there is one) and open a new group.
        if (currentKey.isDefined) {
          f(getIntermediateResults())
        }

        aggregationSpool = aggregations.map(_.createAggregationFunction)
        currentKey = key
      }

      aggregationSpool.foreach(func => func(m))
    })

    // Flush the group that was still open when the source ran out of rows.
    if (currentKey.isDefined) {
      f(getIntermediateResults())
    }
  }
}
@@ -0,0 +1,98 @@
package org.neo4j.cypher.internal.pipes

/**
* Copyright (c) 2002-2011 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

import org.junit.Test
import org.junit.Assert._
import org.junit.matchers.JUnitMatchers._
import scala.collection.JavaConverters._
import org.neo4j.cypher.commands._
import org.scalatest.junit.JUnitSuite
import org.neo4j.cypher.symbols.{IntegerType, SymbolTable, Identifier, NodeType}
import org.scalatest.Assertions
import org.neo4j.cypher.SyntaxException
import org.neo4j.helpers.ThisShouldNotHappenError

/**
 * Tests for OrderedAggregationPipe: exposed symbols, construction failures and the
 * aggregated rows produced for input that is ordered by the grouping key.
 */
class OrderedAggregationPipeTest extends JUnitSuite with Assertions {

  /** The symbol table contains the key column followed by the aggregate column. */
  @Test def shouldReturnColumnsFromReturnItems() {
    val emptySource = new FakePipe(List(), createSymbolTableFor("name"))
    val keys = groupingKeys("name")
    val aggregates = List(CountStar())

    val pipe = new OrderedAggregationPipe(emptySource, keys, aggregates)

    val expectedIdentifiers = Seq(Identifier("name", NodeType()), Identifier("count(*)", IntegerType()))
    assertEquals(expectedIdentifiers, pipe.symbols.identifiers)
  }

  /** Constructing the pipe with an aggregation over an unknown identifier fails. */
  @Test def shouldThrowSemanticException() {
    val emptySource = new FakePipe(List(), createSymbolTableFor("extractReturnItems"))
    val keys = groupingKeys("name")
    val aggregates = List(ValueAggregationItem(Count(Entity("none-existing-identifier"))))

    intercept[SyntaxException](new OrderedAggregationPipe(emptySource, keys, aggregates))
  }

  /** count(*) yields the number of rows in every key run. */
  @Test def shouldAggregateCountStar() {
    val people = List(
      Map("name" -> "Andres", "age" -> 36),
      Map("name" -> "Michael", "age" -> 36),
      Map("name" -> "Michael", "age" -> 31),
      Map("name" -> "Peter", "age" -> 38))
    val orderedSource = new FakePipe(people, createSymbolTableFor("name"))
    val keys = groupingKeys("name")
    val aggregates = List(CountStar())

    val pipe = new OrderedAggregationPipe(orderedSource, keys, aggregates)

    assertThat(pipe.createResults(Map()).toIterable.asJava, hasItems(
      Map("name" -> "Andres", "count(*)" -> 1),
      Map("name" -> "Michael", "count(*)" -> 2),
      Map("name" -> "Peter", "count(*)" -> 1)))
  }

  /** count(age) is 0 for the group whose only row has a null age. */
  @Test def shouldCountNonNullValues() {
    val people = List(
      Map("name" -> "Andres", "age" -> 36),
      Map("name" -> "Michael", "age" -> null),
      Map("name" -> "Peter", "age" -> 38))
    val orderedSource = new FakePipe(people, createSymbolTableFor("name", "age"))
    val keys = groupingKeys("name")
    val aggregates = List(ValueAggregationItem(Count((Entity("age")))))

    val pipe = new OrderedAggregationPipe(orderedSource, keys, aggregates)

    assertThat(pipe.createResults(Map()).toIterable.asJava, hasItems(
      Map("name" -> "Andres", "count(age)" -> 1),
      Map("name" -> "Michael", "count(age)" -> 0),
      Map("name" -> "Peter", "count(age)" -> 1)))
  }

  /** The pipe refuses to be built without grouping keys. */
  @Test def shouldThrowOnEmptyKeyList() {
    val emptySource = new FakePipe(List(), createSymbolTableFor("name"))
    val noKeys = List()
    val aggregates = List(ValueAggregationItem(Count((Entity("name")))))

    intercept[ThisShouldNotHappenError](new OrderedAggregationPipe(emptySource, noKeys, aggregates))
  }

  // Return items grouping on a single entity column.
  private def groupingKeys(column: String) = List(ExpressionReturnItem(Entity(column)))

  // Symbol table in which every given name is registered as a node identifier.
  private def createSymbolTableFor(names: String*) = {
    val identifiers = names.map(name => Identifier(name, NodeType()))
    new SymbolTable(identifiers: _*)
  }

}

0 comments on commit bd1851f

Please sign in to comment.