Permalink
Browse files

Made Cypher better at keeping a numeric type. Adding two integers now…

… returns an integer, and not a double
  • Loading branch information...
1 parent 0d0da14 commit 64f5ba11cfe34add64b34108526864968f981e04 @systay committed Nov 4, 2012
View
@@ -2,6 +2,7 @@
-------------------
o The traversal pattern matcher now supports variable length relationship patterns
o Fixes #946 - HAS(...) fails with ThisShouldNotHappenException for some patterns
+o Made Cypher better at keeping a numeric type. Adding two integers now returns an integer, and not a double.
1.9.M01 (2012-10-23)
--------------------
@@ -21,19 +21,18 @@ package org.neo4j.cypher.internal.commands.expressions
import org.neo4j.cypher.internal.symbols._
import org.neo4j.cypher.CypherTypeException
-import collection.Map
-import org.neo4j.cypher.internal.helpers.IsCollection
+import org.neo4j.cypher.internal.helpers.{TypeSafeMathSupport, IsCollection}
import org.neo4j.cypher.internal.pipes.ExecutionContext
-case class Add(a: Expression, b: Expression) extends Expression {
+case class Add(a: Expression, b: Expression) extends Expression with TypeSafeMathSupport {
def apply(ctx: ExecutionContext) = {
val aVal = a(ctx)
val bVal = b(ctx)
(aVal, bVal) match {
case (null, _) => null
case (_, null) => null
- case (x: Number, y: Number) => x.doubleValue() + y.doubleValue()
+ case (x: Number, y: Number) => plus(x,y)
case (x: String, y: String) => x + y
case (IsCollection(x), IsCollection(y)) => x ++ y
case (IsCollection(x), y) => x ++ Seq(y)
@@ -24,7 +24,7 @@ case class Divide(a: Expression, b: Expression) extends Arithmetics(a, b) {
def verb = "divide"
- def calc(a: Number, b: Number) = a.doubleValue() / b.doubleValue()
+ def calc(a: Number, b: Number) = divide(a, b)
def rewrite(f: (Expression) => Expression) = f(Divide(a.rewrite(f), b.rewrite(f)))
@@ -20,6 +20,7 @@
package org.neo4j.cypher.internal.commands.expressions
import org.neo4j.cypher._
+import internal.helpers.TypeSafeMathSupport
import internal.pipes.ExecutionContext
import internal.symbols._
import collection.Map
@@ -69,7 +70,7 @@ case class CachedExpression(key:String, typ:CypherType) extends Expression {
}
abstract class Arithmetics(left: Expression, right: Expression)
- extends Expression {
+ extends Expression with TypeSafeMathSupport {
def throwTypeError(bVal: Any, aVal: Any): Nothing = {
throw new CypherTypeException("Don't know how to " + this + " `" + bVal + "` with `" + aVal + "`")
}
@@ -84,7 +85,7 @@ abstract class Arithmetics(left: Expression, right: Expression)
}
}
- def calc(a: Number, b: Number): Number
+ def calc(a: Number, b: Number): Any
def filter(f: (Expression) => Boolean) = if(f(this))
Seq(this) ++ left.filter(f) ++ right.filter(f)
@@ -20,7 +20,7 @@
package org.neo4j.cypher.internal.commands.expressions
case class Multiply(a: Expression, b: Expression) extends Arithmetics(a, b) {
- def calc(a: Number, b: Number) = a.doubleValue() * b.doubleValue()
+ def calc(a: Number, b: Number) = multiply(a, b)
def rewrite(f: (Expression) => Expression) = f(Multiply(a.rewrite(f), b.rewrite(f)))
@@ -20,7 +20,7 @@
package org.neo4j.cypher.internal.commands.expressions
case class Subtract(a: Expression, b: Expression) extends Arithmetics(a, b) {
- def calc(a: Number, b: Number) = a.doubleValue() - b.doubleValue()
+ def calc(a: Number, b: Number) = minus(a, b)
def rewrite(f: (Expression) => Expression) = f(Subtract(a.rewrite(f), b.rewrite(f)))
@@ -0,0 +1,222 @@
+/**
+ * Copyright (c) 2002-2012 "Neo Technology,"
+ * Network Engine for Objects in Lund AB [http://neotechnology.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.cypher.internal.helpers
+
+trait TypeSafeMathSupport {
+ def plus(left: Any, right: Any): Any = {
+ (left, right) match {
+ case (null, _) => null
+ case (_, null) => null
+
+ case (l: Byte, r: Byte) => l + r
+ case (l: Byte, r: Double) => l + r
+ case (l: Byte, r: Float) => l + r
+ case (l: Byte, r: Int) => l + r
+ case (l: Byte, r: Long) => l + r
+ case (l: Byte, r: Short) => l + r
+
+ case (l: Double, r: Byte) => l + r
+ case (l: Double, r: Double) => l + r
+ case (l: Double, r: Float) => l + r
+ case (l: Double, r: Int) => l + r
+ case (l: Double, r: Long) => l + r
+ case (l: Double, r: Short) => l + r
+
+ case (l: Float, r: Byte) => l + r
+ case (l: Float, r: Double) => l + r
+ case (l: Float, r: Float) => l + r
+ case (l: Float, r: Int) => l + r
+ case (l: Float, r: Long) => l + r
+ case (l: Float, r: Short) => l + r
+
+ case (l: Int, r: Byte) => l + r
+ case (l: Int, r: Double) => l + r
+ case (l: Int, r: Float) => l + r
+ case (l: Int, r: Int) => l + r
+ case (l: Int, r: Long) => l + r
+ case (l: Int, r: Short) => l + r
+
+ case (l: Long, r: Byte) => l + r
+ case (l: Long, r: Double) => l + r
+ case (l: Long, r: Float) => l + r
+ case (l: Long, r: Int) => l + r
+ case (l: Long, r: Long) => l + r
+ case (l: Long, r: Short) => l + r
+
+ case (l: Short, r: Byte) => l + r
+ case (l: Short, r: Double) => l + r
+ case (l: Short, r: Float) => l + r
+ case (l: Short, r: Int) => l + r
+ case (l: Short, r: Long) => l + r
+ case (l: Short, r: Short) => l + r
+
+ }
+ }
+
+ def divide(left: Any, right: Any): Any = {
+ (left, right) match {
+ case (null, _) => null
+ case (_, null) => null
+
+ case (l: Byte, r: Byte) => l / r
+ case (l: Byte, r: Double) => l / r
+ case (l: Byte, r: Float) => l / r
+ case (l: Byte, r: Int) => l / r
+ case (l: Byte, r: Long) => l / r
+ case (l: Byte, r: Short) => l / r
+
+ case (l: Double, r: Byte) => l / r
+ case (l: Double, r: Double) => l / r
+ case (l: Double, r: Float) => l / r
+ case (l: Double, r: Int) => l / r
+ case (l: Double, r: Long) => l / r
+ case (l: Double, r: Short) => l / r
+
+ case (l: Float, r: Byte) => l / r
+ case (l: Float, r: Double) => l / r
+ case (l: Float, r: Float) => l / r
+ case (l: Float, r: Int) => l / r
+ case (l: Float, r: Long) => l / r
+ case (l: Float, r: Short) => l / r
+
+ case (l: Int, r: Byte) => l / r
+ case (l: Int, r: Double) => l / r
+ case (l: Int, r: Float) => l / r
+ case (l: Int, r: Int) => l / r
+ case (l: Int, r: Long) => l / r
+ case (l: Int, r: Short) => l / r
+
+ case (l: Long, r: Byte) => l / r
+ case (l: Long, r: Double) => l / r
+ case (l: Long, r: Float) => l / r
+ case (l: Long, r: Int) => l / r
+ case (l: Long, r: Long) => l / r
+ case (l: Long, r: Short) => l / r
+
+ case (l: Short, r: Byte) => l / r
+ case (l: Short, r: Double) => l / r
+ case (l: Short, r: Float) => l / r
+ case (l: Short, r: Int) => l / r
+ case (l: Short, r: Long) => l / r
+ case (l: Short, r: Short) => l / r
+
+ }
+ }
+
+ def minus(left: Any, right: Any): Any = {
+ (left, right) match {
+ case (null, _) => null
+ case (_, null) => null
+
+ case (l: Byte, r: Byte) => l - r
+ case (l: Byte, r: Double) => l - r
+ case (l: Byte, r: Float) => l - r
+ case (l: Byte, r: Int) => l - r
+ case (l: Byte, r: Long) => l - r
+ case (l: Byte, r: Short) => l - r
+
+ case (l: Double, r: Byte) => l - r
+ case (l: Double, r: Double) => l - r
+ case (l: Double, r: Float) => l - r
+ case (l: Double, r: Int) => l - r
+ case (l: Double, r: Long) => l - r
+ case (l: Double, r: Short) => l - r
+
+ case (l: Float, r: Byte) => l - r
+ case (l: Float, r: Double) => l - r
+ case (l: Float, r: Float) => l - r
+ case (l: Float, r: Int) => l - r
+ case (l: Float, r: Long) => l - r
+ case (l: Float, r: Short) => l - r
+
+ case (l: Int, r: Byte) => l - r
+ case (l: Int, r: Double) => l - r
+ case (l: Int, r: Float) => l - r
+ case (l: Int, r: Int) => l - r
+ case (l: Int, r: Long) => l - r
+ case (l: Int, r: Short) => l - r
+
+ case (l: Long, r: Byte) => l - r
+ case (l: Long, r: Double) => l - r
+ case (l: Long, r: Float) => l - r
+ case (l: Long, r: Int) => l - r
+ case (l: Long, r: Long) => l - r
+ case (l: Long, r: Short) => l - r
+
+ case (l: Short, r: Byte) => l - r
+ case (l: Short, r: Double) => l - r
+ case (l: Short, r: Float) => l - r
+ case (l: Short, r: Int) => l - r
+ case (l: Short, r: Long) => l - r
+ case (l: Short, r: Short) => l - r
+
+ }
+ }
+
+ def multiply(left: Any, right: Any): Any = {
+ (left, right) match {
+ case (null, _) => null
+ case (_, null) => null
+
+ case (l: Byte, r: Byte) => l * r
+ case (l: Byte, r: Double) => l * r
+ case (l: Byte, r: Float) => l * r
+ case (l: Byte, r: Int) => l * r
+ case (l: Byte, r: Long) => l * r
+ case (l: Byte, r: Short) => l * r
+
+ case (l: Double, r: Byte) => l * r
+ case (l: Double, r: Double) => l * r
+ case (l: Double, r: Float) => l * r
+ case (l: Double, r: Int) => l * r
+ case (l: Double, r: Long) => l * r
+ case (l: Double, r: Short) => l * r
+
+ case (l: Float, r: Byte) => l * r
+ case (l: Float, r: Double) => l * r
+ case (l: Float, r: Float) => l * r
+ case (l: Float, r: Int) => l * r
+ case (l: Float, r: Long) => l * r
+ case (l: Float, r: Short) => l * r
+
+ case (l: Int, r: Byte) => l * r
+ case (l: Int, r: Double) => l * r
+ case (l: Int, r: Float) => l * r
+ case (l: Int, r: Int) => l * r
+ case (l: Int, r: Long) => l * r
+ case (l: Int, r: Short) => l * r
+
+ case (l: Long, r: Byte) => l * r
+ case (l: Long, r: Double) => l * r
+ case (l: Long, r: Float) => l * r
+ case (l: Long, r: Int) => l * r
+ case (l: Long, r: Long) => l * r
+ case (l: Long, r: Short) => l * r
+
+ case (l: Short, r: Byte) => l * r
+ case (l: Short, r: Double) => l * r
+ case (l: Short, r: Float) => l * r
+ case (l: Short, r: Int) => l * r
+ case (l: Short, r: Long) => l * r
+ case (l: Short, r: Short) => l * r
+
+ }
+ }
+}
@@ -20,12 +20,12 @@
package org.neo4j.cypher.internal.pipes.aggregation
import org.neo4j.cypher.internal.commands.expressions.Expression
-import collection.Map
import org.neo4j.cypher.internal.pipes.ExecutionContext
+import org.neo4j.cypher.internal.helpers.TypeSafeMathSupport
class AvgFunction(val value: Expression)
extends AggregationFunction
- with Plus
+ with TypeSafeMathSupport
with NumericExpressionOnly {
def name = "AVG"
Oops, something went wrong.

0 comments on commit 64f5ba1

Please sign in to comment.