Browse files

Merge pull request #594 from systay/relate-lock

RELATE now locks nodes before creating new graph elements
  • Loading branch information...
2 parents 35f4485 + c03ce5e commit f4242561766bd6830ba524e992711413fb137c3e @jexp jexp committed Jun 9, 2012
View
55 cypher/src/main/scala/org/neo4j/cypher/internal/mutation/RelateAction.scala
@@ -22,9 +22,9 @@ package org.neo4j.cypher.internal.mutation
import org.neo4j.cypher.internal.symbols.Identifier
import org.neo4j.cypher.internal.pipes.{QueryState, ExecutionContext}
import org.neo4j.helpers.ThisShouldNotHappenError
-import org.neo4j.graphdb.PropertyContainer
import org.neo4j.cypher.internal.commands.{StartItem, Expression}
import org.neo4j.cypher.RelatePathNotUnique
+import org.neo4j.graphdb.{Lock, PropertyContainer}
case class RelateAction(links: RelateLink*) extends UpdateAction {
def dependencies: Seq[Identifier] = links.flatMap(_.dependencies)
@@ -33,7 +33,7 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
var linksToDo: Seq[RelateLink] = links
var ctx = context
while (linksToDo.nonEmpty) {
- val results = executeAllRemainingPatterns(linksToDo, ctx, state)
+ val results: Seq[(RelateLink, RelateResult)] = executeAllRemainingPatterns(linksToDo, ctx, state)
linksToDo = results.map(_._1)
val updateCommands = extractUpdateCommands(results)
val traversals = extractTraversals(results)
@@ -45,7 +45,12 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
} else if (traversals.nonEmpty) {
ctx = traverseNextStep(traversals, ctx) //We've found some way to move forward. Let's use it
} else if (updateCommands.nonEmpty) {
- ctx = runUpdateCommands(updateCommands, ctx, state) //We could not move forward by traversing - let's build the road
+ val locks = updateCommands.flatMap(_.lock()) //Failed to find a way forward - lock stuff up, and check again
+ try {
+ tryAgain(linksToDo, ctx, state)
+ } finally {
+ locks.foreach(_.release())
+ }
} else {
throw new ThisShouldNotHappenError("Andres", "There was something in that result list I don't know how to handle.")
}
@@ -54,12 +59,34 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
Stream(ctx)
}
+ private def tryAgain(linksToDo: Seq[RelateLink], context: ExecutionContext, state: QueryState): ExecutionContext = {
+ val results: Seq[(RelateLink, RelateResult)] = executeAllRemainingPatterns(linksToDo, context, state)
+ val updateCommands = extractUpdateCommands(results)
+ val traversals = extractTraversals(results)
+
+ if (results.isEmpty) {
+ throw new ThisShouldNotHappenError("Andres", "Second check should never return empty result set")
+ } else if (canNotAdvanced(results)) {
+ throw new ThisShouldNotHappenError("Andres", "Second check should never fail to move forward")
+ } else if (traversals.nonEmpty) {
+ traverseNextStep(traversals, context) //Ah, so this time we did find a traversal way forward. Great!
+ } else if (updateCommands.nonEmpty) {
+ runUpdateCommands(updateCommands.flatMap(_.cmds), context, state) //If we still can't find a way forward,
+ } else { // let's build one
+ throw new ThisShouldNotHappenError("Andres", "There was something in that result list I don't know how to handle.")
+ }
+
+ }
+
private def traverseNextStep(nextSteps: Seq[(String, PropertyContainer)], oldContext: ExecutionContext): ExecutionContext = {
- if (nextSteps.size != nextSteps.distinct.size) {
+ val uniqueKVPs = nextSteps.distinct
+ val uniqueKeys = nextSteps.toMap
+
+ if (uniqueKeys.size != uniqueKVPs.size) {
//We can only go forward following a unique path. Fail.
throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
} else {
- oldContext.newWith(nextSteps)
+ oldContext.newWith(uniqueKeys)
}
}
@@ -91,9 +118,9 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
context
}
- private def extractUpdateCommands(results: scala.Seq[(RelateLink, RelateResult)]): Seq[UpdateWrapper] =
+ private def extractUpdateCommands(results: scala.Seq[(RelateLink, RelateResult)]): Seq[Update] =
results.flatMap {
- case (_, Update(cmds@_*)) => cmds
+ case (_, u: Update) => Some(u)
case _ => None
}
@@ -103,13 +130,7 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
case _ => None
}
- private def executeAllRemainingPatterns(linksToDo: Seq[RelateLink], ctx: ExecutionContext, state: QueryState): Seq[(RelateLink, RelateResult)] =
- linksToDo.flatMap(link => {
- link.exec(ctx, state) match {
- case Done() => None
- case result => Some(link -> result)
- }
- })
+ private def executeAllRemainingPatterns(linksToDo: Seq[RelateLink], ctx: ExecutionContext, state: QueryState): Seq[(RelateLink, RelateResult)] = linksToDo.flatMap(link => link.exec(ctx, state))
private def canNotAdvanced(results: scala.Seq[(RelateLink, RelateResult)]) = results.forall(_._2 == CanNotAdvance())
@@ -122,13 +143,13 @@ case class RelateAction(links: RelateLink*) extends UpdateAction {
sealed abstract class RelateResult
-case class Done() extends RelateResult
-
case class CanNotAdvance() extends RelateResult
case class Traverse(result: (String, PropertyContainer)*) extends RelateResult
-case class Update(cmds: UpdateWrapper*) extends RelateResult
+case class Update(cmds: Seq[UpdateWrapper], locker: () => Seq[Lock]) extends RelateResult {
+ def lock(): Seq[Lock] = locker()
+}
case class UpdateWrapper(needs: Seq[String], cmd: StartItem with UpdateAction) {
def canRun(context: ExecutionContext) = {
View
27 cypher/src/main/scala/org/neo4j/cypher/internal/mutation/RelateLink.scala
@@ -57,22 +57,22 @@ case class RelateLink(start: NamedExpectation, end: NamedExpectation, rel: Named
extends GraphElementPropertyFunctions {
lazy val relationshipType = DynamicRelationshipType.withName(relType)
- def exec(context: ExecutionContext, state: QueryState): RelateResult = {
+ def exec(context: ExecutionContext, state: QueryState): Option[(RelateLink, RelateResult)] = {
// We haven't yet figured out if we already have both elements in the context
// so let's start by finding that first
val s = getNode(context, start.name)
val e = getNode(context, end.name)
(s, e) match {
- case (None, None) => CanNotAdvance()
+ case (None, None) => Some(this->CanNotAdvance())
case (Some(startNode), None) => oneNode(startNode, context, dir, state, end)
case (None, Some(startNode)) => oneNode(startNode, context, dir.reverse(), state, start)
case (Some(startNode), Some(endNode)) => {
if (context.contains(rel.name))
- Done() //We've already solved this pattern.
+ None //We've already solved this pattern.
else
twoNodes(startNode, endNode, context, state)
}
@@ -82,30 +82,39 @@ case class RelateLink(start: NamedExpectation, end: NamedExpectation, rel: Named
// This method sees if a matching relationship already exists between two nodes
// If any matching rels are found, they are returned. Otherwise, a new one is
// created and returned.
- private def twoNodes(startNode: Node, endNode: Node, ctx: ExecutionContext, state: QueryState): RelateResult = {
+ private def twoNodes(startNode: Node, endNode: Node, ctx: ExecutionContext, state: QueryState): Option[(RelateLink, RelateResult)] = {
val rels = startNode.getRelationships(relationshipType, dir).asScala.
filter(r => {
r.getOtherNode(startNode) == endNode && rel.compareWithExpectations(r, ctx)
}).toList
rels match {
- case List() => Update(UpdateWrapper(Seq(), CreateRelationshipStartItem(rel.name, (Literal(startNode),Map()), (Literal(endNode),Map()), relType, rel.properties)))
- case List(r) => Traverse(rel.name -> r)
+ case List() =>
+ val tx = state.transaction.getOrElse(throw new RuntimeException("I need a transaction!"))
+
+ Some(this->Update(Seq(UpdateWrapper(Seq(), CreateRelationshipStartItem(rel.name, (Literal(startNode), Map()), (Literal(endNode), Map()), relType, rel.properties))), () => {
+ Seq(tx.acquireWriteLock(startNode), tx.acquireWriteLock(endNode))
+ }))
+ case List(r) => Some(this->Traverse(rel.name -> r))
case _ => throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
}
}
// When only one node exists in the context, we'll traverse all the relationships of that node
// and try to find a matching node/rel. If matches are found, they are returned. If nothing is
// found, we'll create it and return it
- private def oneNode(startNode: Node, ctx: ExecutionContext, dir: Direction, state: QueryState, end: NamedExpectation): RelateResult = {
+ private def oneNode(startNode: Node, ctx: ExecutionContext, dir: Direction, state: QueryState, end: NamedExpectation): Option[(RelateLink, RelateResult)] = {
val rels = startNode.getRelationships(relationshipType, dir).asScala.filter(r => {
rel.compareWithExpectations(r, ctx) && end.compareWithExpectations(r.getOtherNode(startNode), ctx)
}).toList
rels match {
- case List() => Update(createUpdateActions(dir, startNode, end): _*)
- case List(r) => Traverse(rel.name -> r, end.name -> r.getOtherNode(startNode))
+ case List() =>
+ val tx = state.transaction.getOrElse(throw new RuntimeException("I need a transaction!"))
+ Some(this ->Update(createUpdateActions(dir, startNode, end), () => {
+ Seq(tx.acquireWriteLock(startNode))
+ }))
+ case List(r) => Some(this->Traverse(rel.name -> r, end.name -> r.getOtherNode(startNode)))
case _ => throw new RelatePathNotUnique("The pattern " + this + " produced multiple possible paths, and that is not allowed")
}
}
View
160 cypher/src/test/scala/org/neo4j/cypher/internal/mutation/RelateUniqueTest.scala
@@ -0,0 +1,160 @@
+/**
+ * Copyright (c) 2002-2012 "Neo Technology,"
+ * Network Engine for Objects in Lund AB [http://neotechnology.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.cypher.internal.mutation
+
+import org.junit.Test
+import org.scalatest.Assertions
+import org.neo4j.test.ImpermanentGraphDatabase
+import java.lang.Iterable
+import org.neo4j.graphdb.Traverser.Order
+import org.neo4j.graphdb._
+import org.neo4j.cypher.internal.pipes.{ExecutionContext, QueryState}
+import collection.JavaConverters._
+import collection.mutable.Map
+
+
+/*
+This test tries to set up a situation where RELATE would fail, unless we guard with locks to prevent creating
+multiple relationships.
+
+It does so by using a decorator around ImpermanentGraphDatabase, so directly after RELATE has done getRelationships on
+a node, we'll create a new relationship.
+*/
+
+class RelateUniqueTest extends Assertions {
+ var done = false
+ val db = new ImpermanentGraphDatabase() with TripIt
+
+ @Test def double_check_relate() {
+
+ db.afterGetRelationship = createRel
+
+ val tx = db.beginTx
+
+ val a = try {
+ val a = db.createNode()
+
+ relateAction.exec(createExecutionContext(a), createQueryState(tx))
+
+ tx.success()
+ a
+ } finally {
+ tx.finish()
+ }
+
+ assert(a.getRelationships.asScala.size === 1)
+ }
+
+ val relateAction = RelateAction(RelateLink("a", "b", "r", "X", Direction.OUTGOING))
+
+
+ private def createExecutionContext(a: Node): ExecutionContext = {
+ ExecutionContext.empty.newWith(Map("a" -> a))
+ }
+
+ private def createQueryState(tx: Transaction): QueryState = {
+ new QueryState(db, Map(), Some(tx))
+ }
+
+ private def createRel(node:Node) {
+ if (!done) {
+ done = true
+ val x = db.createNode()
+ node.createRelationshipTo(x, DynamicRelationshipType.withName("X"))
+ }
+
+ }
+}
+
+trait TripIt extends GraphDatabaseService {
+ var afterGetRelationship: Node => Unit = (n) => {}
+
+ abstract override def createNode(): Node = {
+ val n = super.createNode()
+ new PausingNode(n, afterGetRelationship)
+ }
+}
+
+class PausingNode(n: Node, afterGetRelationship: Node => Unit) extends Node {
+ def getId: Long = n.getId
+
+ def delete() {
+ throw new RuntimeException
+ }
+
+ def getRelationships: Iterable[Relationship] = {
+ val rels = n.getRelationships.asScala.toList
+ afterGetRelationship(n)
+ rels.toIterable.asJava
+ }
+
+ def hasRelationship: Boolean = throw new RuntimeException
+
+ def getRelationships(types: RelationshipType*): Iterable[Relationship] = throw new RuntimeException
+
+ def getRelationships(direction: Direction, types: RelationshipType*): Iterable[Relationship] = throw new RuntimeException
+
+ def hasRelationship(types: RelationshipType*): Boolean = throw new RuntimeException
+
+ def hasRelationship(direction: Direction, types: RelationshipType*): Boolean = throw new RuntimeException
+
+ def getRelationships(dir: Direction): Iterable[Relationship] = throw new RuntimeException
+
+ def hasRelationship(dir: Direction): Boolean = throw new RuntimeException
+
+ def getRelationships(`type`: RelationshipType, dir: Direction): Iterable[Relationship] = {
+ val rels = n.getRelationships(`type`, dir).asScala.toList
+ afterGetRelationship(n)
+ rels.toIterable.asJava
+ }
+
+
+ def hasRelationship(`type`: RelationshipType, dir: Direction): Boolean = throw new RuntimeException
+
+ def getSingleRelationship(`type`: RelationshipType, dir: Direction): Relationship = throw new RuntimeException
+
+ def createRelationshipTo(otherNode: Node, `type`: RelationshipType): Relationship = {
+ n.createRelationshipTo(otherNode, `type`)
+ }
+
+ def traverse(traversalOrder: Order, stopEvaluator: StopEvaluator, returnableEvaluator: ReturnableEvaluator, relationshipType: RelationshipType, direction: Direction): Traverser = throw new RuntimeException
+
+ def traverse(traversalOrder: Order, stopEvaluator: StopEvaluator, returnableEvaluator: ReturnableEvaluator, firstRelationshipType: RelationshipType, firstDirection: Direction, secondRelationshipType: RelationshipType, secondDirection: Direction): Traverser = throw new RuntimeException
+
+ def traverse(traversalOrder: Order, stopEvaluator: StopEvaluator, returnableEvaluator: ReturnableEvaluator, relationshipTypesAndDirections: AnyRef*): Traverser = throw new RuntimeException
+
+ def getGraphDatabase: GraphDatabaseService = throw new RuntimeException
+
+ def hasProperty(key: String): Boolean = throw new RuntimeException
+
+ def getProperty(key: String): AnyRef = throw new RuntimeException
+
+ def getProperty(key: String, defaultValue: Any): AnyRef = throw new RuntimeException
+
+ def setProperty(key: String, value: Any) {
+ throw new RuntimeException
+ }
+
+ def removeProperty(key: String): AnyRef = throw new RuntimeException
+
+ def getPropertyKeys: Iterable[String] = throw new RuntimeException
+
+ def getPropertyValues: Iterable[AnyRef] = throw new RuntimeException
+}

0 comments on commit f424256

Please sign in to comment.