Skip to content

Commit

Permalink
Testing the HashJoinPipe
Browse files Browse the repository at this point in the history
  • Loading branch information
systay committed Sep 17, 2012
1 parent d357017 commit 0f752d0
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
@@ -0,0 +1,57 @@
/**
* Copyright (c) 2002-2012 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.pipes


import collection.mutable

class HashJoinPipe(a: Pipe, b: Pipe) extends Pipe {
val keySet = a.symbols.identifiers.keySet.intersect(b.symbols.identifiers.keySet)
assert(keySet.nonEmpty, "No overlap between the incoming pipes exist")

val keySeq = keySet.toSeq

def createResults(state: QueryState): Traversable[ExecutionContext] = {
val table = buildTable(a.createResults(state))

b.createResults(state).flatMap { (entry) =>
table.get(computeKey(entry)) match {
case Some(aList) => aList.map { _.newWith(entry) }
case None => Seq.empty
}
}
}

protected def buildTable(iter: scala.Traversable[ExecutionContext]):
mutable.HashMap[Seq[Any], List[ExecutionContext]] = {
val table = mutable.HashMap[Seq[Any], List[ExecutionContext]]()
for (entry <- iter) {
val key = computeKey(entry)
val l = table.getOrElse(key, List())
table(key) = entry :: l
}
table
}

def computeKey(m: mutable.Map[String, Any]): Seq[Any] = keySeq.map { m(_) }
def symbols = a.symbols.add(b.symbols.identifiers)

def executionPlan() = "HashJoin"
}
@@ -0,0 +1,80 @@
/**
* Copyright (c) 2002-2012 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.pipes

import org.junit.Test
import org.neo4j.cypher.internal.symbols.{StringType, NumberType}
import org.scalatest.Assertions

class HashJoinPipeTest extends Assertions {
@Test def should_not_return_when_missing_matches() {
val a = new FakePipe(Seq(Map("a" -> 0)), "a" -> NumberType())
val b = new FakePipe(Seq(Map("a" -> 1)), "a" -> NumberType())
val joiner = new HashJoinPipe(a, b)

assert(joiner.createResults(QueryState()).isEmpty, "We should not find any matches")
}

@Test def should_complain_about_missing_overlaps() {
val a = new FakePipe(Seq(Map("a" -> 0)), "a" -> NumberType())
val b = new FakePipe(Seq(Map("b" -> 1)), "b" -> NumberType())

intercept[AssertionError](new HashJoinPipe(a, b))
}

@Test def should_join_stuff() {
val a = new FakePipe(Seq(Map("a" -> 0)), "a" -> NumberType())
val b = new FakePipe(Seq(Map("a" -> 0)), "a" -> NumberType())

val joiner = new HashJoinPipe(a, b)

val results = joiner.createResults(QueryState()).toList

assert(results === List(Map("a" -> 0)))
}

@Test def should_join_stuff2() {
val a = new FakePipe(Seq(Map("a" -> 0, "b" -> "Andres")), "a" -> NumberType(), "b" -> StringType())
val b = new FakePipe(Seq(Map("a" -> 0, "c" -> "Stefan")), "a" -> NumberType(), "c" -> StringType())

val joiner = new HashJoinPipe(a, b)

val results = joiner.createResults(QueryState()).toList

assert(results === List(Map("a" -> 0, "b" -> "Andres", "c" -> "Stefan")))
}

@Test def should_join_stuff3() {
val a = new FakePipe(Seq(
Map("a" -> 0, "b" -> "Andres1"),
Map("a" -> 1, "b" -> "Andres2")
), "a" -> NumberType(), "b" -> StringType())
val b = new FakePipe(Seq(
Map("a" -> 0, "c" -> "Stefan1"),
Map("a" -> 2, "c" -> "Stefan2")
), "a" -> NumberType(), "c" -> StringType())

val joiner = new HashJoinPipe(a, b)

val results = joiner.createResults(QueryState()).toList

assert(results === List(Map("a" -> 0, "b" -> "Andres1", "c" -> "Stefan1")))
}
}

0 comments on commit 0f752d0

Please sign in to comment.