Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Refactoring of RelateAction

  • Loading branch information...
commit 0c8b26a68e33c40f0284e8f8c315953a8446e6e8 1 parent e1da572
@systay authored
View
119 cypher/src/main/scala/org/neo4j/cypher/internal/mutation/RelateAction.scala
@@ -23,8 +23,8 @@ import org.neo4j.cypher.internal.symbols.Identifier
import org.neo4j.cypher.internal.pipes.{QueryState, ExecutionContext}
import org.neo4j.helpers.ThisShouldNotHappenError
import org.neo4j.cypher.internal.commands.{StartItem, Expression}
-import org.neo4j.cypher.RelatePathNotUnique
-import org.neo4j.graphdb.{Lock, PropertyContainer}
+import org.neo4j.graphdb.Lock
+import org.neo4j.cypher.{PatternException, RelatePathNotUnique}
case class RelateAction(links: RelateLink*) extends UpdateAction {
def dependencies: Seq[Identifier] = links.flatMap(_.dependencies)
@@ -32,66 +32,51 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
def exec(context: ExecutionContext, state: QueryState): Traversable[ExecutionContext] = {
var linksToDo: Seq[RelateLink] = links
var ctx = context
+
while (linksToDo.nonEmpty) {
- val results: Seq[(RelateLink, RelateResult)] = executeAllRemainingPatterns(linksToDo, ctx, state)
- linksToDo = results.map(_._1)
- val updateCommands = extractUpdateCommands(results)
- val traversals = extractTraversals(results)
-
- if (results.isEmpty) {
- Stream(ctx) //We're done
- } else if (canNotAdvanced(results)) {
- throw new Exception("Unbound pattern!") //None of the patterns can advance. Fail.
- } else if (traversals.nonEmpty) {
- ctx = traverseNextStep(traversals, ctx) //We've found some way to move forward. Let's use it
- } else if (updateCommands.nonEmpty) {
- val locks = updateCommands.flatMap(_.lock()) //Failed to find a way forward - lock stuff up, and check again
- try {
- tryAgain(linksToDo, ctx, state)
- } finally {
- locks.foreach(_.release())
- }
- } else {
- throw new ThisShouldNotHappenError("Andres", "There was something in that result list I don't know how to handle.")
+
+ val results = linksToDo.map(link => link.exec(ctx, state))
+
+ val result = results.reduce(_ reduceWith _)
+
+ linksToDo = result.leftToDo
+
+ result match {
+ case Done() => //We're done! Let's go home
+ case Traverse(a, _) => ctx = ctx.newWith(a.toMap)
+ case Update(commands, locker, _) => ctx = tryAgainWithLocks(locker, linksToDo, ctx, state)
+ case CanNotAdvance(todo) => throw new PatternException("Unbound relate pattern found " + todo)
}
}
Stream(ctx)
}
- private def tryAgain(linksToDo: Seq[RelateLink], context: ExecutionContext, state: QueryState): ExecutionContext = {
- val results: Seq[(RelateLink, RelateResult)] = executeAllRemainingPatterns(linksToDo, context, state)
- val updateCommands = extractUpdateCommands(results)
- val traversals = extractTraversals(results)
-
- if (results.isEmpty) {
- throw new ThisShouldNotHappenError("Andres", "Second check should never return empty result set")
- } else if (canNotAdvanced(results)) {
- throw new ThisShouldNotHappenError("Andres", "Second check should never fail to move forward")
- } else if (traversals.nonEmpty) {
- traverseNextStep(traversals, context) //Ah, so this time we did find a traversal way forward. Great!
- } else if (updateCommands.nonEmpty) {
- runUpdateCommands(updateCommands.flatMap(_.cmds), context, state) //If we still can't find a way forward,
- } else { // let's build one
- throw new ThisShouldNotHappenError("Andres", "There was something in that result list I don't know how to handle.")
- }
- }
+ def tryAgainWithLocks(locker: () => scala.Seq[Lock], linksToDo: scala.Seq[RelateLink], ctx: ExecutionContext, state: QueryState): ExecutionContext = {
+ val locks = locker()
- private def traverseNextStep(nextSteps: Seq[(String, PropertyContainer)], oldContext: ExecutionContext): ExecutionContext = {
- val uniqueKVPs = nextSteps.distinct
- val uniqueKeys = nextSteps.toMap
+ try {
+ tryAgain(linksToDo, ctx, state)
+ } finally {
+ locks.foreach(_.release())
+ }
+ }
- if (uniqueKeys.size != uniqueKVPs.size) {
- //We can only go forward following a unique path. Fail.
- throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
- } else {
- oldContext.newWith(uniqueKeys)
+ private def tryAgain(linksToDo: Seq[RelateLink], context: ExecutionContext, state: QueryState): ExecutionContext = {
+ val result = linksToDo.
+ map(link => link.exec(context, state)).
+ reduce(_ reduceWith _)
+
+ result match {
+ case Traverse(a, _) => context.newWith(a.toMap)
+ case Update(commands, locker, _) => runUpdateCommands(commands, context, state)
+ case _ => throw new ThisShouldNotHappenError("Andres", "Second relate check should only return traverse or update")
}
}
- private def runUpdateCommands(cmds: Seq[UpdateWrapper], oldContext: ExecutionContext, state: QueryState): ExecutionContext = {
- var context = oldContext
+ private def runUpdateCommands(cmds: Seq[UpdateWrapper], startContext: ExecutionContext, state: QueryState): ExecutionContext = {
+ var context = startContext
var todo = cmds.distinct
var done = Seq[String]()
@@ -99,41 +84,23 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
val (unfiltered, temp) = todo.partition(_.canRun(context))
todo = temp
- val current = unfiltered.filterNot(cmd => done.contains(cmd.cmd.identifierName))
+ val current: Seq[UpdateWrapper] = unfiltered.filterNot(cmd => done.contains(cmd.cmd.identifierName))
done = done ++ current.map(_.cmd.identifierName)
context = current.foldLeft(context) {
- case (currentContext, updateCommand) => {
+ case (currentContext, updateCommand) =>
val result = updateCommand.cmd.exec(currentContext, state)
if (result.size != 1) {
throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
} else {
result.head
}
-
- }
}
}
context
}
- private def extractUpdateCommands(results: scala.Seq[(RelateLink, RelateResult)]): Seq[Update] =
- results.flatMap {
- case (_, u: Update) => Some(u)
- case _ => None
- }
-
- private def extractTraversals(results: scala.Seq[(RelateLink, RelateResult)]): Seq[(String, PropertyContainer)] =
- results.flatMap {
- case (_, Traverse(ctx@_*)) => ctx
- case _ => None
- }
-
- private def executeAllRemainingPatterns(linksToDo: Seq[RelateLink], ctx: ExecutionContext, state: QueryState): Seq[(RelateLink, RelateResult)] = linksToDo.flatMap(link => link.exec(ctx, state))
-
- private def canNotAdvanced(results: scala.Seq[(RelateLink, RelateResult)]) = results.forall(_._2 == CanNotAdvance())
-
def filter(f: (Expression) => Boolean): Seq[Expression] = links.flatMap(_.filter(f)).distinct
def identifier: Seq[Identifier] = links.flatMap(_.identifier).distinct
@@ -141,21 +108,11 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
def rewrite(f: (Expression) => Expression): UpdateAction = RelateAction(links.map(_.rewrite(f)): _*)
}
-sealed abstract class RelateResult
-
-case class CanNotAdvance() extends RelateResult
-
-case class Traverse(result: (String, PropertyContainer)*) extends RelateResult
-
-case class Update(cmds: Seq[UpdateWrapper], locker: () => Seq[Lock]) extends RelateResult {
- def lock(): Seq[Lock] = locker()
-}
case class UpdateWrapper(needs: Seq[String], cmd: StartItem with UpdateAction) {
def canRun(context: ExecutionContext) = {
- lazy val keySet = context.keySet
- val forall = needs.forall(keySet.contains)
- forall
+ val keySet = context.keySet
+ needs.forall(keySet.contains)
}
}
View
44 cypher/src/main/scala/org/neo4j/cypher/internal/mutation/RelateLink.scala
@@ -36,7 +36,7 @@ case class NamedExpectation(name: String, properties: Map[String, Expression])
case ("*", expression) => getMapFromExpression(expression(ctx)).forall {
case (k, value) => pc.hasProperty(k) && pc.getProperty(k) == value
}
- case (k, exp) =>
+ case (k, exp) =>
if (!pc.hasProperty(k)) false
else {
val expectationValue = exp(ctx)
@@ -57,7 +57,7 @@ case class RelateLink(start: NamedExpectation, end: NamedExpectation, rel: Named
extends GraphElementPropertyFunctions {
lazy val relationshipType = DynamicRelationshipType.withName(relType)
- def exec(context: ExecutionContext, state: QueryState): Option[(RelateLink, RelateResult)] = {
+ def exec(context: ExecutionContext, state: QueryState): RelateResult = {
// We haven't yet figured out if we already have both elements in the context
// so let's start by finding that first
@@ -65,14 +65,14 @@ case class RelateLink(start: NamedExpectation, end: NamedExpectation, rel: Named
val e = getNode(context, end.name)
(s, e) match {
- case (None, None) => Some(this->CanNotAdvance())
+ case (None, None) => CanNotAdvance(Seq(this))
case (Some(startNode), None) => oneNode(startNode, context, dir, state, end)
case (None, Some(startNode)) => oneNode(startNode, context, dir.reverse(), state, start)
case (Some(startNode), Some(endNode)) => {
if (context.contains(rel.name))
- None //We've already solved this pattern.
+ Done() //We've already solved this pattern.
else
twoNodes(startNode, endNode, context, state)
}
@@ -82,49 +82,49 @@ case class RelateLink(start: NamedExpectation, end: NamedExpectation, rel: Named
// This method sees if a matching relationship already exists between two nodes
// If any matching rels are found, they are returned. Otherwise, a new one is
// created and returned.
- private def twoNodes(startNode: Node, endNode: Node, ctx: ExecutionContext, state: QueryState): Option[(RelateLink, RelateResult)] = {
- val rels = startNode.getRelationships(relationshipType, dir).asScala.
+ private def twoNodes(startNode: Node, endNode: Node, ctx: ExecutionContext, state: QueryState): RelateResult = {
+ val matchingRelationships = startNode.getRelationships(relationshipType, dir).asScala.
filter(r => {
r.getOtherNode(startNode) == endNode && rel.compareWithExpectations(r, ctx)
}).toList
- rels match {
+ matchingRelationships match {
case List() =>
val tx = state.transaction.getOrElse(throw new RuntimeException("I need a transaction!"))
+ val locker = () => Seq(tx.acquireWriteLock(startNode), tx.acquireWriteLock(endNode))
+ val wrapper = UpdateWrapper(Seq(), CreateRelationshipStartItem(rel.name, (Literal(startNode), Map()), (Literal(endNode), Map()), relType, rel.properties))
+ Update(Seq(wrapper), locker, Seq(this))
- Some(this->Update(Seq(UpdateWrapper(Seq(), CreateRelationshipStartItem(rel.name, (Literal(startNode), Map()), (Literal(endNode), Map()), relType, rel.properties))), () => {
- Seq(tx.acquireWriteLock(startNode), tx.acquireWriteLock(endNode))
- }))
- case List(r) => Some(this->Traverse(rel.name -> r))
- case _ => throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
+ case List(r) => Traverse(result = Seq(rel.name -> r), leftToDo = Seq())
+ case _ => throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
}
}
// When only one node exists in the context, we'll traverse all the relationships of that node
// and try to find a matching node/rel. If matches are found, they are returned. If nothing is
// found, we'll create it and return it
- private def oneNode(startNode: Node, ctx: ExecutionContext, dir: Direction, state: QueryState, end: NamedExpectation): Option[(RelateLink, RelateResult)] = {
+ private def oneNode(startNode: Node, ctx: ExecutionContext, dir: Direction, state: QueryState, end: NamedExpectation): RelateResult = {
val rels = startNode.getRelationships(relationshipType, dir).asScala.filter(r => {
rel.compareWithExpectations(r, ctx) && end.compareWithExpectations(r.getOtherNode(startNode), ctx)
}).toList
rels match {
- case List() =>
+ case List() =>
val tx = state.transaction.getOrElse(throw new RuntimeException("I need a transaction!"))
- Some(this ->Update(createUpdateActions(dir, startNode, end), () => {
- Seq(tx.acquireWriteLock(startNode))
- }))
- case List(r) => Some(this->Traverse(rel.name -> r, end.name -> r.getOtherNode(startNode)))
- case _ => throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
+ val locker = () => Seq(tx.acquireWriteLock(startNode))
+ Update(createUpdateActions(dir, startNode, end), locker, Seq(this))
+
+ case List(r) => Traverse(Seq(rel.name -> r, end.name -> r.getOtherNode(startNode)), Seq())
+ case _ => throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
}
}
private def createUpdateActions(dir: Direction, startNode: Node, end: NamedExpectation): Seq[UpdateWrapper] = {
val createRel = if (dir == Direction.OUTGOING) {
- CreateRelationshipStartItem(rel.name, (Literal(startNode),Map()), (Entity(end.name),Map()), relType, rel.properties)
+ CreateRelationshipStartItem(rel.name, (Literal(startNode), Map()), (Entity(end.name), Map()), relType, rel.properties)
} else {
- CreateRelationshipStartItem(rel.name, (Entity(end.name),Map()), (Literal(startNode),Map()), relType, rel.properties)
+ CreateRelationshipStartItem(rel.name, (Entity(end.name), Map()), (Literal(startNode), Map()), relType, rel.properties)
}
val relUpdate = UpdateWrapper(Seq(end.name), createRel)
@@ -137,7 +137,7 @@ case class RelateLink(start: NamedExpectation, end: NamedExpectation, rel: Named
None
} else context.get(key).map {
case n: Node => n
- case x => throw new CypherTypeException("Expected `" + key + "` to a node, but it is a " + x)
+ case x => throw new CypherTypeException("Expected `" + key + "` to a node, but it is a " + x)
}
lazy val identifier = Seq(Identifier(start.name, NodeType()), Identifier(end.name, NodeType()), Identifier(rel.name, RelationshipType()))
View
88 cypher/src/main/scala/org/neo4j/cypher/internal/mutation/RelateResult.scala
@@ -0,0 +1,88 @@
+/**
+ * Copyright (c) 2002-2012 "Neo Technology,"
+ * Network Engine for Objects in Lund AB [http://neotechnology.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.cypher.internal.mutation
+
+import org.neo4j.graphdb.{Lock, PropertyContainer}
+import org.neo4j.cypher.RelatePathNotUnique
+
+sealed abstract class RelateResult {
+ def reduceWith(x: RelateResult): RelateResult = (this, x) match {
+
+ case (Done(), Done()) => Done()
+ case (Done(), other) => other
+ case (other, Done()) => other
+
+
+ case (
+ CanNotAdvance(todoA),
+ CanNotAdvance(todoB)) => CanNotAdvance(todoA ++ todoB)
+ case (CanNotAdvance(todoA), other) => other.addTodo(todoA)
+ case (other, CanNotAdvance(todoA)) => other.addTodo(todoA)
+
+
+ case (
+ Traverse(leftToDoA, resultA),
+ Traverse(leftToDoB, resultB)) => Traverse(leftToDoA ++ leftToDoB, resultA ++ resultB)
+ case (traverse: Traverse, other) => traverse.addTodo(other.leftToDo)
+ case (other, traverse: Traverse) => traverse.addTodo(other.leftToDo)
+
+
+ case (
+ Update(cmdsA, lockerA, leftToDoA),
+ Update(cmdsB, lockerB, leftToDoB)) => Update(cmdsA ++ cmdsB, () => lockerA() ++ lockerB(), leftToDoA ++ leftToDoB)
+ }
+
+
+ def leftToDo: Seq[RelateLink]
+
+ def addTodo(todo: Seq[RelateLink]): RelateResult
+}
+
+case class Done() extends RelateResult {
+ def leftToDo = Seq()
+
+ def addTodo(todo: Seq[RelateLink]): RelateResult = throw new Exception("Done can't add todos")
+}
+
+case class CanNotAdvance(leftToDo: Seq[RelateLink]) extends RelateResult {
+ def addTodo(todo: Seq[RelateLink]): RelateResult = CanNotAdvance(leftToDo ++ todo)
+}
+
+case class Traverse(result: Seq[(String, PropertyContainer)], leftToDo: Seq[RelateLink]) extends RelateResult {
+ assertValidResult()
+
+ def addTodo(todo: Seq[RelateLink]): RelateResult = Traverse(result, leftToDo ++ todo)
+
+ private def assertValidResult() {
+ val uniqueKVPs = result.distinct
+ val uniqueKeys = result.toMap
+
+ if (uniqueKeys.size != uniqueKVPs.size) {
+ //We can only go forward following a unique path. Fail.
+ throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
+ }
+ }
+}
+
+case class Update(cmds: Seq[UpdateWrapper], locker: () => Seq[Lock], leftToDo: Seq[RelateLink]) extends RelateResult {
+ def lock(): Seq[Lock] = locker()
+
+ def addTodo(todo: Seq[RelateLink]): RelateResult = Update(cmds, locker, leftToDo ++ todo)
+}
View
4 cypher/src/test/scala/org/neo4j/cypher/RelateAcceptanceTests.scala
@@ -79,7 +79,7 @@ class RelateAcceptanceTests extends ExecutionEngineHelper with Assertions with S
val c = createNode()
val d = createNode()
- val result = parseAndExecute("start a = node(1,2), b=node(3), c=node(4) relate a-[:X]->b-[:X]->c")
+ val result = parseAndExecute("start a = node(1,2), b=node(3), c=node(4) relate a-[r1:X]->b-[r2:X]->c")
assertStats(result, relationshipsCreated = 3)
@@ -93,7 +93,7 @@ class RelateAcceptanceTests extends ExecutionEngineHelper with Assertions with S
def creates_minimal_amount_of_nodes_reverse() {
val a = createNode()
- val result = parseAndExecute("start a = node(1) relate c-[:X]->b-[:X]->a")
+ val result = parseAndExecute("start a = node(1) relate c-[r2:X]->b-[r1:X]->a")
assertStats(result, nodesCreated = 2, relationshipsCreated = 2)
View
89 cypher/src/test/scala/org/neo4j/cypher/internal/mutation/RelateResultTest.scala
@@ -0,0 +1,89 @@
+/**
+ * Copyright (c) 2002-2012 "Neo Technology,"
+ * Network Engine for Objects in Lund AB [http://neotechnology.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.cypher.internal.mutation
+
+import org.junit.Test
+import org.junit.Assert._
+import org.neo4j.cypher.internal.commands.{Entity, CreateRelationshipStartItem}
+import org.hamcrest.CoreMatchers._
+import org.neo4j.graphdb.{PropertyContainer, Direction}
+
+class RelateResultTest {
+ val A = RelateLink("a", "b", "r", "X", Direction.OUTGOING)
+ val B = RelateLink("a", "c", "r", "X", Direction.OUTGOING)
+ val C = RelateLink("a", "d", "r", "X", Direction.OUTGOING)
+ val cmd1 = UpdateWrapper(Seq("a", "b"), CreateRelationshipStartItem("r", (Entity("a"), Map()), (Entity("b"), Map()), "x", Map()))
+ val cmd2 = UpdateWrapper(Seq("a", "c"), CreateRelationshipStartItem("r", (Entity("a"), Map()), (Entity("c"), Map()), "x", Map()))
+
+ val done1 = Done()
+ val done2 = Done()
+ val cantAdvance1 = CanNotAdvance(leftToDo = Seq(A))
+ val cantAdvance2 = CanNotAdvance(leftToDo = Seq(B))
+ val traverse1 = Traverse(leftToDo = Seq(A), result = Seq("a" -> null))
+ val traverse2 = Traverse(leftToDo = Seq(C), result = Seq("b" -> null))
+ val update1 = Update(Seq(cmd1), () => Seq(), Seq(C))
+ val update2 = Update(Seq(cmd2), () => Seq(), Seq(B))
+
+
+ @Test def testDone() {
+ assertThat(done1 reduceWith done2, instanceOf(classOf[Done]))
+ assertThat(done1 reduceWith cantAdvance1, equalTo[RelateResult](cantAdvance1))
+ assertThat(done1 reduceWith traverse1, equalTo[RelateResult](traverse1))
+ assertThat(done1 reduceWith update1, equalTo[RelateResult](update1))
+ }
+
+ @Test def cant_advanced_with_cant_advance() {
+ val reduced = cantAdvance1 reduceWith cantAdvance2
+ assertThat(reduced.leftToDo, equalTo(Seq(A, B)))
+ assertThat(reduced, instanceOf(classOf[CanNotAdvance]))
+ }
+
+ @Test def cant_advanced_with_traverse() {
+ val reduced = cantAdvance1 reduceWith traverse2
+ assertThat(reduced.leftToDo.toSet, equalTo(Set(A, C)))
+ assertThat(reduced, instanceOf(classOf[Traverse]))
+ }
+
+ @Test def cant_advanced_with_update() {
+ val reduced = cantAdvance1 reduceWith update1
+ assertThat(reduced.leftToDo.toSet, equalTo(Set(A, C)))
+ assertThat(reduced, instanceOf(classOf[Update]))
+ }
+
+ @Test def traverse_w_traverse() {
+ val reduced = traverse1 reduceWith traverse2
+ assertThat(reduced.leftToDo.toSet, equalTo(Set(A, C)))
+ assertThat(reduced, instanceOf(classOf[Traverse]))
+ assertThat(reduced.asInstanceOf[Traverse].result, equalTo(Seq[(String, PropertyContainer)]("a" -> null, "b" -> null)))
+ }
+
+ @Test def traverse_w_update() {
+ val reduced = traverse1 reduceWith update2
+ assertThat(reduced.leftToDo.toSet, equalTo(Set(A, B)))
+ assertThat(reduced, instanceOf(classOf[Traverse]))
+ }
+
+ @Test def update_w_update() {
+ val reduced = (update1 reduceWith update2).asInstanceOf[Update]
+ assertThat(reduced.leftToDo.toSet, equalTo(Set(B, C)))
+ assertThat(reduced, instanceOf(classOf[Update]))
+ assertThat(reduced.cmds, equalTo(Seq(cmd1, cmd2)))
+ }
+}
Please sign in to comment.
Something went wrong with that request. Please try again.