Skip to content

Commit

Permalink
feat: Compute transitive closure of upvars
Browse files Browse the repository at this point in the history
Instead of computing transitive closure of function calls.

10x faster! (500ms -> 50ms in self compilation)
  • Loading branch information
vain0x committed Jan 1, 2021
1 parent be84295 commit 4671606
Showing 1 changed file with 68 additions and 104 deletions.
172 changes: 68 additions & 104 deletions MiloneLang/ClosureConversion.fs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
/// By traversing over the expression, calculates the following sets
/// (call it known context here) at every point.
///
/// - Known: set of functions defined outside the current function.
/// - Locals: set of variables defined inside the current function.
/// - Refs: set of variables or functions that occurs in the current function.
///
Expand All @@ -44,7 +43,7 @@
/// ## Transformation
///
/// Here the set of captured variables can easily compute
/// as `(Refs \ Locals) \ Known` using known context of functions,
/// as `(Refs \ Locals)` using known context of functions,
/// where `s \ t` is set difference.
///
/// Three rules:
Expand Down Expand Up @@ -94,14 +93,12 @@ open MiloneLang.Hir
[<RequireQualifiedAccess>]
[<NoEquality; NoComparison>]
type private KnownCtx =
{ Known: AssocSet<FunSerial>
Locals: AssocSet<VarSerial>
{ Locals: AssocSet<VarSerial>
UseVars: AssocSet<VarSerial>
UseFuns: AssocSet<FunSerial> }

let private knownCtxEmpty (): KnownCtx =
{ Known = setEmpty funSerialCmp
Locals = setEmpty varSerialCmp
{ Locals = setEmpty varSerialCmp
UseVars = setEmpty varSerialCmp
UseFuns = setEmpty funSerialCmp }

Expand All @@ -117,10 +114,6 @@ let private knownCtxLeaveFunDecl (baseCtx: KnownCtx) (ctx: KnownCtx) =
UseFuns = baseCtx.UseFuns
Locals = baseCtx.Locals }

let private knownCtxAddKnown serial (ctx: KnownCtx) =
{ ctx with
Known = ctx.Known |> setAdd serial }

let private knownCtxAddLocal serial (ctx: KnownCtx) =
{ ctx with
Locals = ctx.Locals |> setAdd serial }
Expand All @@ -133,28 +126,15 @@ let private knownCtxUseFun funSerial (ctx: KnownCtx) =
{ ctx with
UseFuns = ctx.UseFuns |> setAdd funSerial }

/// Returns serials referenced in the current context but not known nor locals
/// including function serials.
let private knownCtxToNonlocalUses (ctx: KnownCtx): VarSerial list * FunSerial list =
let vars =
ctx.UseVars
|> setFold
(fun acc varSerial ->
if ctx.Locals |> setContains varSerial |> not
then varSerial :: acc
else acc)
[]

let funs =
ctx.UseFuns
|> setFold
(fun acc funSerial ->
if ctx.Known |> setContains funSerial |> not
then funSerial :: acc
else acc)
[]

vars, funs
let private knownCtxToNonlocalVars (ctx: KnownCtx): AssocSet<VarSerial> =
ctx.UseVars
|> setFold
(fun acc varSerial ->
if ctx.Locals |> setContains varSerial |> not then
acc |> setAdd varSerial
else
acc)
(setEmpty varSerialCmp)

// -----------------------------------------------
// Caps
Expand Down Expand Up @@ -208,7 +188,8 @@ type private CcCtx =
Funs: AssocMap<FunSerial, FunDef>
Tys: AssocMap<TySerial, TyDef>
Current: KnownCtx
FunKnowns: AssocMap<FunSerial, KnownCtx> }
FunKnowns: AssocMap<FunSerial, KnownCtx>
FunUpvars: AssocMap<FunSerial, AssocSet<VarSerial>> }

let private ofTyCtx (tyCtx: TyCtx): CcCtx =
{ Serial = tyCtx.Serial
Expand All @@ -217,7 +198,8 @@ let private ofTyCtx (tyCtx: TyCtx): CcCtx =
Tys = tyCtx.Tys

Current = knownCtxEmpty ()
FunKnowns = mapEmpty funSerialCmp }
FunKnowns = mapEmpty funSerialCmp
FunUpvars = mapEmpty funSerialCmp }

let private toTyCtx (tyCtx: TyCtx) (ctx: CcCtx) =
{ tyCtx with
Expand All @@ -226,10 +208,6 @@ let private toTyCtx (tyCtx: TyCtx) (ctx: CcCtx) =
Funs = ctx.Funs
Tys = ctx.Tys }

let private addKnown funSerial (ctx: CcCtx) =
{ ctx with
Current = ctx.Current |> knownCtxAddKnown funSerial }

let private addLocal varSerial (ctx: CcCtx) =
{ ctx with
Current = ctx.Current |> knownCtxAddLocal varSerial }
Expand All @@ -250,8 +228,6 @@ let private useFun funSerial (ctx: CcCtx) =

/// Called on leave function declaration to store the current known context.
let private saveKnownCtxToFun funSerial (ctx: CcCtx) =
let ctx = ctx |> addKnown funSerial

// Don't update in the second traversal for transformation.
let funKnowns = ctx.FunKnowns

Expand All @@ -273,84 +249,72 @@ let private leaveFunDecl funSerial (baseCtx: CcCtx) (ctx: CcCtx) =
ctx.Current
|> knownCtxLeaveFunDecl baseCtx.Current }

/// Gets a list of captured variable serials for a function
/// including referenced functions.
let private getCapturedSerials funSerial (ctx: CcCtx) =
match ctx.FunKnowns |> mapTryFind funSerial with
| Some knownCtx -> knownCtx |> knownCtxToNonlocalUses
| None -> [], []

/// Gets a list of captured variables for a function.
/// Doesn't include referenced functions.
let private genFunCaps funSerial (ctx: CcCtx): Caps =
let varSerials, _ = ctx |> getCapturedSerials funSerial
let varSerials =
match ctx.FunUpvars |> mapTryFind funSerial with
| Some it -> it |> setToList
| None -> []

// FIXME: List.rev here is just to reduce diff. Remove later.
varSerials
|> List.choose
(fun varSerial ->
match ctx.Vars |> mapTryFind varSerial with
| Some (VarDef (_, AutoSM, ty, loc)) -> Some(varSerial, ty, loc)

| _ -> None)
|> List.rev

/// Extends the set of references to be transitive.
/// E.g. a function `f` uses `g` and `g` uses `h` (and `h` uses etc.),
/// we think `f` also uses `h`.
let private closureRefs (ctx: CcCtx): CcCtx =
let rec doClosureRefs vars funs ccCtx (modified, visited, acc) =
match vars, funs with
| [], [] -> modified, visited, acc

| varSerial :: vars, _ ->
let revisit = acc |> setContains varSerial
let modified = modified || not revisit

let acc =
if revisit then acc else acc |> setAdd varSerial

doClosureRefs vars funs ccCtx (modified, visited, acc)

| _, funSerial :: funs ->
let revisit = visited |> setContains funSerial
let modified = modified || not revisit

let visited =
if revisit then visited else visited |> setAdd funSerial

let otherVars, otherFuns = ccCtx |> getCapturedSerials funSerial

(modified, visited, acc)
|> doClosureRefs otherVars otherFuns ccCtx
|> doClosureRefs vars funs ccCtx

let closureKnownCtx (modified, ccCtx) funSerial (knownCtx: KnownCtx) =
let vars = knownCtx.UseVars
let funs = knownCtx.UseFuns

match (false, funs, vars)
|> doClosureRefs (vars |> setToList) (funs |> setToList) ccCtx with
| true, visited, vars ->
true,
{ ccCtx with
FunKnowns =
ccCtx.FunKnowns
|> mapAdd
funSerial
{ knownCtx with
UseVars = vars
UseFuns = visited } }

| false, _, _ -> modified, ccCtx

let rec closureFuns (modified, ccCtx: CcCtx) =
if not modified then
ccCtx
let mergeUpvars localVars newUpvars (modified, upvars): bool * AssocSet<VarSerial> =
newUpvars
|> setFold
(fun (modified, upvars) varSerial ->
if upvars |> setContains varSerial
|| localVars |> setContains varSerial then
modified, upvars
else
true, upvars |> setAdd varSerial)
(modified, upvars)

let visitFun (totalModified, funUpvarsMap) funSerial (upvars, localVars, funs) =
let modified, upvars =
funs
|> List.fold
(fun (modified, upvars) funSerial ->
let newUpvars, _, _ = funUpvarsMap |> mapFind funSerial

(modified, upvars)
|> mergeUpvars localVars newUpvars)
(false, upvars)

if modified then
true,
funUpvarsMap
|> mapAdd funSerial (upvars, localVars, funs)
else
totalModified, funUpvarsMap

let rec makeTransitive funUpvarsMap =
let modified, funUpvarsMap =
funUpvarsMap
|> mapFold visitFun (false, funUpvarsMap)

if modified then
makeTransitive funUpvarsMap
else
ccCtx.FunKnowns
|> mapFold closureKnownCtx (false, ccCtx)
|> closureFuns
funUpvarsMap

let funUpvars =
ctx.FunKnowns
|> mapMap
(fun (_: FunSerial) (knownCtx: KnownCtx) ->
knownCtxToNonlocalVars knownCtx, knownCtx.Locals, setToList knownCtx.UseFuns)
|> makeTransitive
|> mapMap (fun (_: FunSerial) (upvars: AssocSet<VarSerial>, _: AssocSet<VarSerial>, _: FunSerial list) -> upvars)

closureFuns (true, ctx)
{ ctx with FunUpvars = funUpvars }

/// Applies the changes of function types.
let private updateFunDefs (ctx: CcCtx) =
Expand Down

0 comments on commit 4671606

Please sign in to comment.