Skip to content

Commit 815f60f

Browse files
committed
Refine equality of Constant types over floating point values.
The constant types for 0d and -0d should not be equal. This is implemented by checking equality of the result of doubleToRawLongBits / floatToRawIntBits, which also correctly considers two NaNs of the same flavour to be equal. Followup to SI-6331.
1 parent c619f94 commit 815f60f

File tree

3 files changed

+96
-2
lines changed

3 files changed

+96
-2
lines changed

src/reflect/scala/reflect/internal/Constants.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,16 @@ trait Constants extends api.Constants {
8383
*/
8484
override def equals(other: Any): Boolean = other match {
8585
case that: Constant =>
86-
this.tag == that.tag &&
87-
(this.value == that.value || this.isNaN && that.isNaN)
86+
// Consider two NaNs to be identical, despite non-equality
87+
// Consider -0d to be distinct from 0d, despite equality
88+
import java.lang.Double.doubleToRawLongBits
89+
import java.lang.Float.floatToRawIntBits
90+
91+
this.tag == that.tag && ((value, that.value) match {
92+
case (f1: Float, f2: Float) => floatToRawIntBits(f1) == floatToRawIntBits(f2)
93+
case (d1: Double, d2: Double) => doubleToRawLongBits(d1) == doubleToRawLongBits(d2)
94+
case (v1, v2) => v1 == v2
95+
})
8896
case _ => false
8997
}
9098

test/files/run/t6331.check

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
() == ()
2+
true == true
3+
true != false
4+
false != true
5+
0.toByte == 0.toByte
6+
0.toByte != 1.toByte
7+
0.toShort == 0.toShort
8+
0.toShort != 1.toShort
9+
0 == 0
10+
0 != 1
11+
0L == 0L
12+
0L != 1L
13+
0.0f == 0.0f
14+
0.0f != -0.0f
15+
-0.0f != 0.0f
16+
NaNf == NaNf
17+
0.0d == 0.0d
18+
0.0d != -0.0d
19+
-0.0d != 0.0d
20+
NaNd == NaNd
21+
0 != 0.0d
22+
0 != 0L
23+
0.0d != 0.0f

test/files/run/t6331.scala

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import scala.tools.partest._
2+
import java.io._
3+
import scala.tools.nsc._
4+
import scala.tools.nsc.util.CommandLineParser
5+
import scala.tools.nsc.{Global, Settings, CompilerCommand}
6+
import scala.tools.nsc.reporters.ConsoleReporter
7+
8+
// Test of Constant#equals, which must must account for floating point intricacies.
9+
object Test extends DirectTest {
10+
11+
override def code = ""
12+
13+
override def show() {
14+
val global = newCompiler()
15+
import global._
16+
17+
def check(c1: Any, c2: Any): Unit = {
18+
val equal = Constant(c1) == Constant(c2)
19+
def show(a: Any) = "" + a + (a match {
20+
case _: Byte => ".toByte"
21+
case _: Short => ".toShort"
22+
case _: Long => "L"
23+
case _: Float => "f"
24+
case _: Double => "d"
25+
case _ => ""
26+
})
27+
val op = if (equal) "==" else "!="
28+
println(f"${show(c1)}%12s $op ${show(c2)}")
29+
}
30+
31+
check((), ())
32+
33+
check(true, true)
34+
check(true, false)
35+
check(false, true)
36+
37+
check(0.toByte, 0.toByte)
38+
check(0.toByte, 1.toByte)
39+
40+
check(0.toShort, 0.toShort)
41+
check(0.toShort, 1.toShort)
42+
43+
check(0, 0)
44+
check(0, 1)
45+
46+
check(0L, 0L)
47+
check(0L, 1L)
48+
49+
check(0f, 0f)
50+
check(0f, -0f)
51+
check(-0f, 0f)
52+
check(Float.NaN, Float.NaN)
53+
54+
check(0d, 0d)
55+
check(0d, -0d)
56+
check(-0d, 0d)
57+
check(Double.NaN, Double.NaN)
58+
59+
check(0, 0d)
60+
check(0, 0L)
61+
check(0d, 0f)
62+
}
63+
}

0 commit comments

Comments
 (0)