diff --git a/build.sbt b/build.sbt index d0304c2..45d3ab9 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,4 @@ -scalaVersion := "2.11.0-RC1" +scalaVersion := "2.10.3" sourceGenerators in Compile <+= sourceManaged in Compile map { dir => def write(name: String, content: String) = { diff --git a/project/CodeGen.scala b/project/CodeGen.scala index f449e6c..adec5e7 100644 --- a/project/CodeGen.scala +++ b/project/CodeGen.scala @@ -147,71 +147,124 @@ object CodeGen { | * Copyright (C) 2012-2014 Typesafe Inc. | */""".stripMargin.trim - private def apply0MethodSpec(r: Type): String = { - val name = "apply$mc" + s"${r.code}" + "$sp" - val applyCall = s"apply();" - def body = if (r == Type.Void) applyCall else s"return (${r.ref}) $applyCall" - s""" - |default ${r.prim} $name() { - | $body - |} - |""".stripMargin.trim + private def function0SpecMethods = { + val apply = specialized("apply", function0Spec) { + case (name, List(r)) => + val applyCall = s"apply();" + def body = if (r == Type.Void) applyCall else s"return (${r.ref}) $applyCall" + s""" + |default ${r.prim} $name() { + | $body + |} + |""".stripMargin.trim + } + indent(apply) } - private def apply0SpecMethods = { + private val function0Spec = { val rs = List(Type.Void, Type.Byte, Type.Short, Type.Int, Type.Long, Type.Char, Type.Float, Type.Double, Type.Boolean) - val methods = for (r <- rs) yield apply0MethodSpec(r) - methods.map(indent).mkString("\n\n") - } - - private def apply1MethodSpec(t1: Type, r: Type): String = { - val name = "apply$mc" + s"${r.code}${t1.code}" + "$sp" - val applyCall = s"apply((T1) ((${t1.ref}) v1));" - def body = if (r == Type.Void) applyCall else s"return (${r.ref}) $applyCall" - - s""" - |default ${r.prim} $name(${t1.prim} v1) { - | $body - |} - |""".stripMargin.trim + List("R" -> rs) } - - private def apply1SpecMethods = { + private val function1Spec = { val ts = List(Type.Int, Type.Long, Type.Float, Type.Double) val rs = List(Type.Void, Type.Boolean, Type.Int, Type.Float, Type.Long, Type.Double) - val methods = for (t1 <- ts; r <- rs) yield apply1MethodSpec(t1, r) - methods.map(indent).mkString("\n\n") + List("T1" -> ts, "R" -> rs) + } + private val function2Spec = { + val ts = List(Type.Int, Type.Long, Type.Double) + val rs = List(Type.Void, Type.Boolean, Type.Int, Type.Float, Type.Long, Type.Double) + List("T1" -> ts, "T2" -> ts, "R" -> rs) } - private def apply2MethodSpec(t1: Type, t2: Type, r: Type): String = { - val name = "apply$mc" + s"${r.code}${t1.code}${t2.code}" + "$sp" - val applyCall = s"apply((T1) ((${t1.ref}) v1), (T2) ((${t2.ref}) v2));" - def body = if (r == Type.Void) applyCall else s"return (${r.ref}) $applyCall" + private def function1SpecMethods = { + val apply = specialized("apply", function1Spec) { + case (name, List(t1, r)) => + val applyCall = s"apply((T1) ((${t1.ref}) v1));" + def body = if (r == Type.Void) applyCall else s"return (${r.ref}) $applyCall" + s""" + |default ${r.prim} $name(${t1.prim} v1) { + | $body + |} + |""".stripMargin.trim + } + // andThen / compose variants are no longer needed under 2.11 (@unspecialized has been fixed), + // but harmless. With them, we can use the same artifact for 2.10 and 2.11 + val compose = specialized("compose", function1Spec) { + case (name, List(t1, r1)) => + s""" + |default scala.Function1 $name(scala.Function1 g) { + | return compose(g); + |}""".stripMargin.trim + } + val andThen = specialized("andThen", function1Spec) { + case (name, List(t1, r1)) => + s""" + |default scala.Function1 $name(scala.Function1 g) { + | return andThen(g); + |}""".stripMargin.trim + } + indent(List(apply, compose, andThen).mkString("\n\n")) + } - s""" - |default ${r.prim} $name(${t1.prim} v1, ${t2.prim} v2) { - | $body - |} - |""".stripMargin.trim + // No longer needed under 2.11 (@unspecialized has been fixed), but harmless to keep around to avoid cross-publishing this artifact. + private def function2SpecMethods = { + val apply = specialized("apply", function2Spec) { + case (name, List(t1, t2, r)) => + val applyCall = s"apply((T1) ((${t1.ref}) v1), (T2) ((${t2.ref}) v2));" + def body = if (r == Type.Void) applyCall else s"return (${r.ref}) $applyCall" + + s""" + |default ${r.prim} $name(${t1.prim} v1, ${t2.prim} v2) { + | $body + |} + |""".stripMargin.trim + } + val curried = specialized("curried", function2Spec) { + case (name, List(t1, t2, r)) => + s""" + |default scala.Function1 $name() { + | return curried(); + |}""".stripMargin.trim + } + val tupled = specialized("tupled", function2Spec) { + case (name, List(t1, t2, r)) => + s""" + |default scala.Function1 $name() { + | return tupled(); + |}""".stripMargin.trim + } + indent(List(apply, curried, tupled).mkString("\n\n")) } - private def apply2SpecMethods = { - val ts = List(Type.Int, Type.Long, Type.Double) - val rs = List(Type.Void, Type.Boolean, Type.Int, Type.Float, Type.Long, Type.Double) - val methods = for (t1 <- ts; t2 <- ts; r <- rs) yield apply2MethodSpec(t1, t2, r) - methods.map(indent).mkString("\n\n") + private def specialized(name: String, tps: List[(String, List[Type])])(f: (String, List[Type]) => String): String = { + val tparamNames = tps.map(_._1) + def code(tps: List[Type]) = { + val sorted = (tps zip tparamNames).sortBy(_._2).map(_._1) // as per scalac, sort by tparam name before assembling the code + sorted.map(_.code).mkString + } + val ms = for { + variantTypes <- crossProduct(tps.map(_._2)) + specName = name + "$mc" + code(variantTypes) + "$sp" + } yield f(specName, variantTypes) + ms.mkString("\n") + } + + def crossProduct[A](input: List[List[A]]): List[List[A]] = input match { + case Nil => Nil + case head :: Nil => head.map(_ :: Nil) + case head :: tail => for (elem <- head; sub <- crossProduct(tail)) yield elem :: sub } def fN(n: Int) = { val header = arity(n).fHeader - val applyMethods = n match { - case 0 => apply0SpecMethods - case 1 => apply1SpecMethods - case 2 => apply2SpecMethods + val specializedVariants = n match { + case 0 => function0SpecMethods + case 1 => function1SpecMethods + case 2 => function2SpecMethods case x => "" } - val trailer = "}\n" - List(header, applyMethods, trailer).mkString + val trailer = "\n}\n" + List(header, specializedVariants, trailer).mkString } def pN(n: Int) = arity(n).pN