Skip to content

Commit

Permalink
Allows case classes as value classes
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Mar 7, 2012
1 parent e9a1207 commit 54e284d
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 12 deletions.
26 changes: 16 additions & 10 deletions src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ trait SyntheticMethods extends ast.TreeDSL {
List(
Product_productPrefix -> (() => constantNullary(nme.productPrefix, clazz.name.decode)),
Product_productArity -> (() => constantNullary(nme.productArity, arity)),
Product_productElement -> (() => perElementMethod(nme.productElement, accessorLub)(Ident)),
Product_productElement -> (() => perElementMethod(nme.productElement, accessorLub)(Select(This(clazz), _))),
Product_iterator -> (() => productIteratorMethod),
Product_canEqual -> (() => canEqualMethod)
// This is disabled pending a reimplementation which doesn't add any
Expand All @@ -226,24 +226,28 @@ trait SyntheticMethods extends ast.TreeDSL {
)
}

def valueClassMethods = List(
Any_hashCode -> (() => hashCodeDerivedValueClassMethod),
Any_equals -> (() => equalsDerivedValueClassMethod)
)

def caseClassMethods = productMethods ++ productNMethods ++ Seq(
Object_hashCode -> (() => forwardToRuntime(Object_hashCode)),
Object_toString -> (() => forwardToRuntime(Object_toString)),
Object_equals -> (() => equalsCaseClassMethod)
)

def valueCaseClassMethods = productMethods ++ productNMethods ++ valueClassMethods ++ Seq(
Any_toString -> (() => forwardToRuntime(Object_toString))
)

def caseObjectMethods = productMethods ++ Seq(
Object_hashCode -> (() => constantMethod(nme.hashCode_, clazz.name.decode.hashCode)),
Object_toString -> (() => constantMethod(nme.toString_, clazz.name.decode))
// Not needed, as reference equality is the default.
// Object_equals -> (() => createMethod(Object_equals)(m => This(clazz) ANY_EQ Ident(m.firstParam)))
)

def inlineClassMethods = List(
Any_hashCode -> (() => hashCodeDerivedValueClassMethod),
Any_equals -> (() => equalsDerivedValueClassMethod)
)

/** If you serialize a singleton and then deserialize it twice,
* you will have two instances of your singleton unless you implement
* readResolve. Here it is implemented for all objects which have
Expand All @@ -258,10 +262,12 @@ trait SyntheticMethods extends ast.TreeDSL {

def synthesize(): List[Tree] = {
val methods = (
if (clazz.isDerivedValueClass) inlineClassMethods
else if (!clazz.isCase) Nil
else if (clazz.isModuleClass) caseObjectMethods
else caseClassMethods
if (clazz.isCase)
if (clazz.isDerivedValueClass) valueCaseClassMethods
else if (clazz.isModuleClass) caseObjectMethods
else caseClassMethods
else if (clazz.isDerivedValueClass) valueClassMethods
else Nil
)

def impls = for ((m, impl) <- methods ; if !hasOverridingImplementation(m)) yield impl()
Expand Down
2 changes: 1 addition & 1 deletion src/library/scala/Serializable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ package scala
/**
* Classes extending this trait are serializable across platforms (Java, .NET).
*/
trait Serializable extends java.io.Serializable
trait Serializable extends Any with java.io.Serializable
2 changes: 1 addition & 1 deletion test/files/pos/t715/meredith_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.sap.dspace.model.othello;
import scala.xml._

trait XMLRenderer {
type T <: {def getClass() : java.lang.Class[_]}
type T <: Any {def getClass() : java.lang.Class[_]}
val valueTypes =
List(
classOf[java.lang.Boolean],
Expand Down
21 changes: 21 additions & 0 deletions test/files/run/MeterCaseClass.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
2.0
Meter(4.0)
false
x.isInstanceOf[Meter]: true
x.hashCode: 1072693248
x == 1: false
x == y: true
a == b: true
testing native arrays
Array(Meter(1.0), Meter(2.0))
Meter(1.0)
>>>Meter(1.0)<<< Meter(1.0)
>>>Meter(2.0)<<< Meter(2.0)
testing wrapped arrays
FlatArray(Meter(1.0), Meter(2.0))
Meter(1.0)
>>>Meter(1.0)<<< Meter(1.0)
>>>Meter(2.0)<<< Meter(2.0)
FlatArray(Meter(2.0), Meter(3.0))
ArrayBuffer(1.0, 2.0)
FlatArray(0.3048ft, 0.6096ft)
99 changes: 99 additions & 0 deletions test/files/run/MeterCaseClass.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package a {
case class Meter(underlying: Double) extends AnyVal with _root_.b.Printable {
def + (other: Meter): Meter =
new Meter(this.underlying + other.underlying)
def / (other: Meter): Double = this.underlying / other.underlying
def / (factor: Double): Meter = new Meter(this.underlying / factor)
def < (other: Meter): Boolean = this.underlying < other.underlying
def toFoot: Foot = new Foot(this.underlying * 0.3048)
override def print = { Console.print(">>>"); super.print; proprint }
}

object Meter extends (Double => Meter) {

implicit val boxings = new BoxingConversions[Meter, Double] {
def box(x: Double) = new Meter(x)
def unbox(m: Meter) = m.underlying
}
}

class Foot(val unbox: Double) extends AnyVal {
def + (other: Foot): Foot =
new Foot(this.unbox + other.unbox)
override def toString = unbox.toString+"ft"
}
object Foot {
implicit val boxings = new BoxingConversions[Foot, Double] {
def box(x: Double) = new Foot(x)
def unbox(m: Foot) = m.unbox
}
}

}
package b {
trait Printable extends Any {
def print: Unit = Console.print(this)
protected def proprint = Console.print("<<<")
}
}
import a._
import _root_.b._
object Test extends App {

{
val x: Meter = new Meter(1)
val a: Object = x.asInstanceOf[Object]
val y: Meter = a.asInstanceOf[Meter]

val u: Double = 1
val b: Object = u.asInstanceOf[Object]
val v: Double = b.asInstanceOf[Double]
}

val x = new Meter(1)
val y = x
println((x + x) / x)
println((x + x) / 0.5)
println((x < x).toString)
println("x.isInstanceOf[Meter]: "+x.isInstanceOf[Meter])


println("x.hashCode: "+x.hashCode)
println("x == 1: "+(x == 1))
println("x == y: "+(x == y))
assert(x.hashCode == (1.0).hashCode)

val a: Any = x
val b: Any = y
println("a == b: "+(a == b))

{ println("testing native arrays")
val arr = Array(x, y + x)
println(arr.deep)
def foo[T <: Printable](x: Array[T]) {
for (i <- 0 until x.length) { x(i).print; println(" "+x(i)) }
}
val m = arr(0)
println(m)
foo(arr)
}

{ println("testing wrapped arrays")
import collection.mutable.FlatArray
val arr = FlatArray(x, y + x)
println(arr)
def foo(x: FlatArray[Meter]) {
for (i <- 0 until x.length) { x(i).print; println(" "+x(i)) }
}
val m = arr(0)
println(m)
foo(arr)
val ys: Seq[Meter] = arr map (_ + new Meter(1))
println(ys)
val zs = arr map (_ / Meter(1))
println(zs)
val fs = arr map (_.toFoot)
println(fs)
}

}

0 comments on commit 54e284d

Please sign in to comment.