Skip to content

Commit

Permalink
Change closure handling
Browse files Browse the repository at this point in the history
Constrain closure parameters and result from expected type before rechecking the closure's
body. This gives more precise types and avoids the spurious duplication of some
variables.

It also avoids the unmotivated special case that we needed before to make tests pass.
  • Loading branch information
odersky committed Aug 11, 2023
1 parent 6339276 commit c4aefa1
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 105 deletions.
83 changes: 27 additions & 56 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,16 @@ class CheckCaptures extends Recheck, SymTransformer:
else if meth == defn.Caps_unsafeUnbox then
mapArgUsing(_.forceBoxStatus(false))
else if meth == defn.Caps_unsafeBoxFunArg then
mapArgUsing:
def forceBox(tp: Type): Type = tp match
case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual)

case tp @ RefinedType(parent, rname, rinfo: MethodType) =>
tp.derivedRefinedType(parent, rname,
rinfo.derivedLambdaType(
paramInfos = rinfo.paramInfos.map(_.forceBoxStatus(true))))
case tp @ CapturingType(parent, refs) =>
tp.derivedCapturingType(forceBox(parent), refs)
mapArgUsing(forceBox)
else
super.recheckApply(tree, pt) match
case appType @ CapturingType(appType1, refs) =>
Expand Down Expand Up @@ -485,63 +491,28 @@ class CheckCaptures extends Recheck, SymTransformer:
else ownType
end instantiate

override def recheckClosure(tree: Closure, pt: Type)(using Context): Type =
override def recheckClosure(tree: Closure, pt: Type, forceDependent: Boolean)(using Context): Type =
val cs = capturedVars(tree.meth.symbol)
capt.println(i"typing closure $tree with cvs $cs")
super.recheckClosure(tree, pt).capturing(cs)
.showing(i"rechecked $tree / $pt = $result", capt)

/** Additionally to normal processing, update types of closures if the expected type
* is a function with only pure parameters. In that case, make the anonymous function
* also have the same parameters as the prototype.
* TODO: Develop a clearer rationale for this.
* TODO: Can we generalize this to arbitrary parameters?
* Currently some tests fail if we do this. (e.g. neg.../stackAlloc.scala, others)
*/
override def recheckBlock(block: Block, pt: Type)(using Context): Type =
block match
case closureDef(mdef) =>
pt.dealias match
case defn.FunctionOf(ptformals, _, _)
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
// Redo setup of the anonymous function so that formal parameters don't
// get capture sets. This is important to avoid false widenings to `cap`
// when taking the base type of the actual closures's dependent function
// type so that it conforms to the expected non-dependent function type.
// See withLogFile.scala for a test case.
val meth = mdef.symbol
// First, undo the previous setup which installed a completer for `meth`.
atPhase(preRecheckPhase.prev)(meth.denot.copySymDenotation())
.installAfter(preRecheckPhase)

// Next, update all parameter symbols to match expected formals
meth.paramSymss.head.lazyZip(ptformals).foreach: (psym, pformal) =>
psym.updateInfoBetween(preRecheckPhase, thisPhase, pformal.mapExprType)

// Next, update types of parameter ValDefs
mdef.paramss.head.lazyZip(ptformals).foreach: (param, pformal) =>
val ValDef(_, tpt, _) = param: @unchecked
tpt.rememberTypeAlways(pformal)

// Next, install a new completer reflecting the new parameters for the anonymous method
val mt = meth.info.asInstanceOf[MethodType]
val completer = new LazyType:
def complete(denot: SymDenotation)(using Context) =
denot.info = mt.companion(ptformals, mdef.tpt.knownType)
.showing(i"simplify info of $meth to $result", capt)
recheckDef(mdef, meth)
meth.updateInfoBetween(preRecheckPhase, thisPhase, completer)
case _ =>
mdef.rhs match
case rhs @ closure(_, _, _) =>
// In a curried closure `x => y => e` don't leak capabilities retained by
// the second closure `y => e` into the first one. This is an approximation
// of the CC rule which says that a closure contributes captures to its
// environment only if a let-bound reference to the closure is used.
mdef.rhs.putAttachment(ClosureBodyValue, ())
case _ =>
super.recheckClosure(tree, pt, forceDependent).capturing(cs)
.showing(i"rechecked closure $tree / $pt = $result", capt)

override def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type =
mdef.rhs match
case rhs @ closure(_, _, _) =>
// In a curried closure `x => y => e` don't leak capabilities retained by
// the second closure `y => e` into the first one. This is an approximation
// of the CC rule which says that a closure contributes captures to its
// environment only if a let-bound reference to the closure is used.
mdef.rhs.putAttachment(ClosureBodyValue, ())
case _ =>
super.recheckBlock(block, pt)

// Constrain closure's parameters and result from the expected type before
// rechecking the body.
val res = recheckClosure(expr, pt, forceDependent = true)
recheckDef(mdef, mdef.symbol)
res
end recheckClosureBlock

override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
try
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,17 @@ extends tpd.TreeTraverser:
val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil)
.showing(i"update info $sym: ${sym.info} --> $result", capt)
if newInfo ne sym.info then
val completer = new LazyType:
def complete(denot: SymDenotation)(using Context) =
denot.info = newInfo
recheckDef(tree, sym)
updateInfo(sym, completer)
updateInfo(sym,
if sym.isAnonymousFunction then
// closures are handled specially; the newInfo is constrained from
// the expected type and only afterwards we recheck the definition
newInfo
else new LazyType:
def complete(denot: SymDenotation)(using Context) =
// infos other methods are determined from their definitions which
// are checked on depand
denot.info = newInfo
recheckDef(tree, sym))
case tree: Bind =>
val sym = tree.symbol
updateInfo(sym, transformInferredType(sym.info, boxed = false))
Expand Down
99 changes: 60 additions & 39 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,18 @@ object Recheck:
*/
def updateInfoBetween(prevPhase: DenotTransformer, lastPhase: DenotTransformer, newInfo: Type)(using Context): Unit =
if sym.info ne newInfo then
val flags = sym.flags
sym.copySymDenotation(
initFlags =
if sym.flags.isAllOf(ResetPrivateParamAccessor)
then sym.flags &~ ResetPrivate | Private
else sym.flags
if flags.isAllOf(ResetPrivateParamAccessor)
then flags &~ ResetPrivate | Private
else flags
).installAfter(lastPhase) // reset
sym.copySymDenotation(
info = newInfo,
initFlags =
if newInfo.isInstanceOf[LazyType] then sym.flags &~ Touched
else sym.flags
if newInfo.isInstanceOf[LazyType] then flags &~ Touched
else flags
).installAfter(prevPhase)

/** Does symbol have a new denotation valid from phase.next that is different
Expand Down Expand Up @@ -96,17 +97,44 @@ object Recheck:
case Some(tpe) => tree.withType(tpe).asInstanceOf[T]
case None => tree

extension (tpe: Type)

/** Map ExprType => T to () ?=> T (and analogously for pure versions).
* Even though this phase runs after ElimByName, ExprTypes can still occur
* as by-name arguments of applied types. See note in doc comment for
* ElimByName phase. Test case is bynamefun.scala.
*/
def mapExprType(using Context): Type = tpe match
case ExprType(rt) => defn.ByNameFunction(rt)
case _ => tpe

/** Map ExprType => T to () ?=> T (and analogously for pure versions).
* Even though this phase runs after ElimByName, ExprTypes can still occur
* as by-name arguments of applied types. See note in doc comment for
* ElimByName phase. Test case is bynamefun.scala.
*/
private def mapExprType(tp: Type)(using Context): Type = tp match
case ExprType(rt) => defn.ByNameFunction(rt)
case _ => tp

/** Normalize `=> A` types to `() ?=> A` types
* - at the top level
* - in function and method parameter types
* - under annotations
*/
def normalizeByName(tp: Type)(using Context): Type = tp match
case tp: ExprType =>
mapExprType(tp)
case tp: PolyType =>
tp.derivedLambdaType(resType = normalizeByName(tp.resType))
case tp: MethodType =>
tp.derivedLambdaType(
paramInfos = tp.paramInfos.mapConserve(mapExprType),
resType = normalizeByName(tp.resType))
case tp @ RefinedType(parent, nme.apply, rinfo) if defn.isFunctionType(tp) =>
tp.derivedRefinedType(parent, nme.apply, normalizeByName(rinfo))
case tp @ defn.FunctionOf(pformals, restpe, isContextual) =>
val pformals1 = pformals.mapConserve(mapExprType)
val restpe1 = normalizeByName(restpe)
if (pformals1 ne pformals) || (restpe1 ne restpe) then
defn.FunctionOf(pformals1, restpe1, isContextual)
else
tp
case tp @ AnnotatedType(parent, ann) =>
tp.derivedAnnotatedType(normalizeByName(parent), ann)
case _ =>
tp
end Recheck

/** A base class that runs a simplified typer pass over an already re-typed program. The pass
* does not transform trees but returns instead the re-typed type of each tree as it is
Expand Down Expand Up @@ -183,27 +211,16 @@ abstract class Recheck extends Phase, SymTransformer:
else AnySelectionProto
recheckSelection(tree, recheck(qual, proto).widenIfUnstable, name, pt)

/** When we select the `apply` of a function with type such as `(=> A) => B`,
* we need to convert the parameter type `=> A` to `() ?=> A`. See doc comment
* of `mapExprType`.
*/
def normalizeByName(mbr: SingleDenotation)(using Context): SingleDenotation = mbr.info match
case mt: MethodType if mt.paramInfos.exists(_.isInstanceOf[ExprType]) =>
mbr.derivedSingleDenotation(mbr.symbol,
mt.derivedLambdaType(paramInfos = mt.paramInfos.map(_.mapExprType)))
case _ =>
mbr

def recheckSelection(tree: Select, qualType: Type, name: Name,
sharpen: Denotation => Denotation)(using Context): Type =
if name.is(OuterSelectName) then tree.tpe
else
//val pre = ta.maybeSkolemizePrefix(qualType, name)
val mbr = normalizeByName(
val mbr =
sharpen(
qualType.findMember(name, qualType,
excluded = if tree.symbol.is(Private) then EmptyFlags else Private
)).suchThat(tree.symbol == _))
)).suchThat(tree.symbol == _)
val newType = tree.tpe match
case prevType: NamedType =>
val prevDenot = prevType.denot
Expand Down Expand Up @@ -281,7 +298,7 @@ abstract class Recheck extends Phase, SymTransformer:
else fntpe.paramInfos
def recheckArgs(args: List[Tree], formals: List[Type], prefs: List[ParamRef]): List[Type] = args match
case arg :: args1 =>
val argType = recheck(arg, formals.head.mapExprType)
val argType = recheck(arg, normalizeByName(formals.head))
val formals1 =
if fntpe.isParamDependent
then formals.tail.map(_.substParam(prefs.head, argType))
Expand Down Expand Up @@ -313,27 +330,33 @@ abstract class Recheck extends Phase, SymTransformer:
recheck(tree.rhs, lhsType.widen)
defn.UnitType

def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type =
private def recheckBlock(stats: List[Tree], expr: Tree)(using Context): Type =
recheckStats(stats)
val exprType = recheck(expr)
TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm))

def recheckBlock(tree: Block, pt: Type)(using Context): Type = tree match
case Block(Nil, expr: Block) => recheckBlock(expr, pt)
case Block((mdef : DefDef) :: Nil, closure: Closure) =>
recheckClosureBlock(mdef, closure.withSpan(tree.span), pt)
case Block(stats, expr) => recheckBlock(stats, expr)
// The expected type `pt` is not propagated. Doing so would allow variables in the
// expected type to contain references to local symbols of the block, so the
// local symbols could escape that way.
TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm))

def recheckBlock(tree: Block, pt: Type)(using Context): Type =
recheckBlock(tree.stats, tree.expr, pt)
def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type =
recheckBlock(mdef :: Nil, expr)

def recheckInlined(tree: Inlined, pt: Type)(using Context): Type =
recheckBlock(tree.bindings, tree.expansion, pt)(using inlineContext(tree))
recheckBlock(tree.bindings, tree.expansion)(using inlineContext(tree))

def recheckIf(tree: If, pt: Type)(using Context): Type =
recheck(tree.cond, defn.BooleanType)
recheck(tree.thenp, pt) | recheck(tree.elsep, pt)

def recheckClosure(tree: Closure, pt: Type)(using Context): Type =
def recheckClosure(tree: Closure, pt: Type, forceDependent: Boolean = false)(using Context): Type =
if tree.tpt.isEmpty then
tree.meth.tpe.widen.toFunctionType(tree.meth.symbol.is(JavaDefined))
tree.meth.tpe.widen.toFunctionType(tree.meth.symbol.is(JavaDefined), alwaysDependent = forceDependent)
else
recheck(tree.tpt)

Expand Down Expand Up @@ -534,9 +557,7 @@ abstract class Recheck extends Phase, SymTransformer:

/** Check that widened types of `tpe` and `pt` are compatible. */
def checkConforms(tpe: Type, pt: Type, tree: Tree)(using Context): Unit = tree match
case _: DefTree | EmptyTree | _: TypeTree | _: Closure =>
// Don't report closure nodes, since their span is a point; wait instead
// for enclosing block to preduce an error
case _: DefTree | EmptyTree | _: TypeTree =>
case _ =>
checkConformsExpr(tpe.widenExpr, pt.widenExpr, tree)

Expand Down
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/capt1.check
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:14:2 -----------------------------------------
14 | def f(y: Int) = if x == null then y else y // error
| ^
| Found: Int ->{x} Int
| Found: (y: Int) ->{x} Int
| Required: Matchable
15 | f
|
Expand Down
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/try.check
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
| This is often caused by a local capability in an argument of method handle
| leaking as part of its result.
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:29:43 ------------------------------------------
29 | val b = handle[Exception, () -> Nothing] { // error
29 | val b = handle[Exception, () -> Nothing] { // error
| ^
| Found: (x: CT[Exception]^) ->? () ->{x} Nothing
| Required: (x$0: CanThrow[Exception]) => () -> Nothing
Expand Down
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/try.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test =
(ex: Exception) => ???
}

val b = handle[Exception, () -> Nothing] { // error
val b = handle[Exception, () -> Nothing] { // error
(x: CanThrow[Exception]) => () => raise(new Exception)(using x)
} {
(ex: Exception) => ???
Expand Down
7 changes: 5 additions & 2 deletions tests/pos-custom-args/captures/bynamefun.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
object test:
class Plan(elem: Plan)
object SomePlan extends Plan(???)
type PP = (-> Plan) -> Plan
def f1(expr: (-> Plan) -> Plan): Plan = expr(SomePlan)
f1 { onf => Plan(onf) }
def f2(expr: (=> Plan) -> Plan): Plan = ???
f2 { onf => Plan(onf) }
def f3(expr: (-> Plan) => Plan): Plan = ???
f1 { onf => Plan(onf) }
f3 { onf => Plan(onf) }
def f4(expr: (=> Plan) => Plan): Plan = ???
f2 { onf => Plan(onf) }
f4 { onf => Plan(onf) }
def f5(expr: PP): Plan = expr(SomePlan)
f5 { onf => Plan(onf) }

0 comments on commit c4aefa1

Please sign in to comment.