diff --git a/plugin/src/main/scala/dslparadise/package.scala b/plugin/src/main/scala/dslparadise/package.scala index d7e3450..53b7481 100644 --- a/plugin/src/main/scala/dslparadise/package.scala +++ b/plugin/src/main/scala/dslparadise/package.scala @@ -1,7 +1,7 @@ package object dslparadise { type `implicit =>`[-T, +R] = T => R -// type `import._ =>`[-T, +R] = T => R -// -// type `import._`[+T, +I] = T + type `import._ =>`[-T, +R] = T => R + + type `import._`[+T, I] = T } diff --git a/plugin/src/main/scala/dslparadise/typechecker/Typers.scala b/plugin/src/main/scala/dslparadise/typechecker/Typers.scala index b0b61a0..e282d1d 100644 --- a/plugin/src/main/scala/dslparadise/typechecker/Typers.scala +++ b/plugin/src/main/scala/dslparadise/typechecker/Typers.scala @@ -3,7 +3,6 @@ package typechecker import scala.reflect.NameTransformer import scala.reflect.internal.Mode -import dslparadise._ trait Typers { self: Analyzer => @@ -11,37 +10,65 @@ trait Typers { import global._ trait ParadiseTyper extends Typer with TyperContextErrors { - override def typedArg(arg: Tree, mode: Mode, newmode: Mode, pt: Type): Tree = { -// val pre = typeOf[dslparadise.`package`.type] + // rewriting rules for the DSL Paradise types + val rewritings = Map( + "dslparadise.implicit =>" -> { (arg: Tree, pt: Type) => + q"{ implicit $$bang => $arg }" + }, - pt match { -// case TypeRef(`pre`, sym, _) if sym.name.decodedName.toString == "implicit =>" => - case TypeRef(_, sym, _) - if NameTransformer.decode(sym.fullName) == "dslparadise.implicit =>" => + "dslparadise.import._ =>" -> { (arg: Tree, pt: Type) => + q"{ $$bang => import $$bang._; $arg }" + }, - val convertArg = context inSilentMode { - super.typedArg(arg.duplicate, mode, newmode, pt) - context.reporter.hasErrors - } + "dslparadise.import._" -> { (arg: Tree, pt: Type) => + q"{ import ${pt.typeArgs(1).typeSymbol.companionSymbol}._; $arg }" + } + ) - val newarg = if (convertArg) { - val newarg = q"{ implicit! => $arg }" + override def typedArg(arg: Tree, mode: Mode, newmode: Mode, pt: Type): Tree = { + val newarg = pt match { + case TypeRef(_, sym, _) => + // find rewriting rule for the expected type + rewritings get (NameTransformer decode sym.fullName) map { rewrite => - val keepArg = context inSilentMode { - super.typedArg(newarg.duplicate, mode, newmode, pt) - context.reporter.errors exists { _.errPos == NoPosition } + // only rewrite argument if it does not compile in its current form + val rewriteArg = context inSilentMode { + super.typedArg(arg.duplicate, mode, newmode, pt) + context.reporter.hasErrors } - if (keepArg) arg else newarg - } - else - arg + if (rewriteArg) { + // apply rewriting rule + val newarg = rewrite(arg, pt) - super.typedArg(newarg, mode, newmode, pt) + // to improve compiler-issued error messages, keep the original + // (non-rewritten) argument if the new (rewritten) argument + // produces compile errors + // - that have no corresponding position in the source file + // (i.e. the position is within the code that was introduced + // by the rewriting) or + // - whose message is "missing parameter type", which could be + // misleading if the rewriting introduced a function and the + // original code already had function type + val keepArg = context inSilentMode { + super.typedArg(newarg.duplicate, mode, newmode, pt) + context.reporter.errors exists { e => + e.errPos == NoPosition || e.errMsg == "missing parameter type" + } + } + + if (keepArg) arg else newarg + } + else + arg + + } getOrElse arg case _ => - super.typedArg(arg, mode, newmode, pt) + arg } + + super.typedArg(newarg, mode, newmode, pt) } } }