Skip to content

Commit

Permalink
split long functions resulting from migrating compiletime values (#921)
Browse files Browse the repository at this point in the history
* split long functions resulting from migrating compiletime values

see #920
  • Loading branch information
peq authored and Frotty committed Jan 14, 2020
1 parent 2dad075 commit c8b3a7e
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import de.peeeq.wurstscript.intermediatelang.interpreter.ILStackFrame;
import de.peeeq.wurstscript.intermediatelang.interpreter.LocalState;
import de.peeeq.wurstscript.intermediatelang.interpreter.ProgramState;
import de.peeeq.wurstscript.intermediatelang.optimizer.FunctionSplitter;
import de.peeeq.wurstscript.jassIm.*;
import de.peeeq.wurstscript.jassinterpreter.TestFailException;
import de.peeeq.wurstscript.jassinterpreter.TestSuccessException;
Expand Down Expand Up @@ -108,6 +109,8 @@ public void run() {
}
runDelayedActions();

partitionCompiletimeStateInitFunction();

} catch (InterpreterException e) {
Element origin = e.getTrace();
sendErrors(origin, e.getMessage(), e);
Expand All @@ -131,6 +134,14 @@ public void run() {

}

private void partitionCompiletimeStateInitFunction() {
if (compiletimeStateInitFunction == null) {
return;
}

FunctionSplitter.splitFunc(translator, compiletimeStateInitFunction);
}

private boolean isUnitTestMode() {
return Optional.ofNullable(imProg)
.map(ImProg::attrTrace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,22 @@ private ExecutableJassFunction searchFunction(String fname) {
}

private ExecutableJassFunction searchNativeJassFunction(String name) {
if (name.equals("ExecuteFunc")) {
return executeFuncNative();
}
ReflectionNativeProvider nf = new ReflectionNativeProvider(this);
ExecutableJassFunction functionPair = nf.getFunctionPair(name);
return functionPair != null ? functionPair : new UnknownJassFunction(name);
}

private ExecutableJassFunction executeFuncNative() {
return (jassInterpreter, arguments) -> {
ILconstString funcName = (ILconstString) arguments[0];
jassInterpreter.executeFunction(funcName.getVal());
return ILconstBool.TRUE;
};
}

public void trace(boolean b) {
trace = b;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class RunArgs {
private RunOption optionMeasureTimes;
private RunOption optionHotStartmap;
private RunOption optionHotReload;
private int functionSplitLimit = 10000;

private RunOption optionBuild;

Expand Down Expand Up @@ -125,6 +126,8 @@ public RunArgs(String... args) {
addOptionWithArg("inputmap", "The next argument should be the input map.", arg -> inputmap = arg);
optionLua = addOption("lua", "Choose Lua as the compilation target.");

addOptionWithArg("functionSplitLimit", "The maximum number of operations in a function before it is split by the function splitter (used for compiletime functions)",
s -> functionSplitLimit = Integer.parseInt(s, 10));

nextArg:
for (int i = 0; i < args.length; i++) {
Expand Down Expand Up @@ -344,4 +347,9 @@ public boolean isLua() {
return optionLua.isSet;
}


public int getFunctionSplitLimit() {
return functionSplitLimit;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public String toString() {

}

private void optimizeFunc(ImFunction func) {
void optimizeFunc(ImFunction func) {
ControlFlowGraph cfg = new ControlFlowGraph(func.getBody());
Map<Node, Knowledge> knowledge = calculateKnowledge(cfg);
rewriteCode(cfg, knowledge);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
package de.peeeq.wurstscript.intermediatelang.optimizer;

import com.google.common.base.Preconditions;
import de.peeeq.wurstscript.attributes.CompileError;
import de.peeeq.wurstscript.jassIm.*;
import de.peeeq.wurstscript.translation.imtranslation.CallType;
import de.peeeq.wurstscript.translation.imtranslation.ImTranslator;
import de.peeeq.wurstscript.translation.imtranslation.UsedVariables;

import java.util.*;
import java.util.stream.Collectors;

/**
* Splits a long function into several smaller functions which are
* executed using TriggerEvaluate
*/
public class FunctionSplitter {

private final int op_limit;
private final ImTranslator tr;
private final ImFunction func;
private final Map<ImFunction, Integer> fuelVisited;

private FunctionSplitter(int op_limit, ImTranslator tr, ImFunction func) {
this.op_limit = op_limit;
this.tr = tr;
this.func = func;
this.fuelVisited = new LinkedHashMap<>();
fuelVisited.put(func, null);
}

public static void splitFunc(ImTranslator tr, ImFunction func) {
new FunctionSplitter(tr.getRunArgs().getFunctionSplitLimit(), tr, func).optimize();

}

private void optimize() {
Preconditions.checkArgument(func.getTypeVariables().isEmpty(), "func must not be generic");
Preconditions.checkArgument(func.getParameters().isEmpty(), "func parameters must be empty");
Preconditions.checkArgument(func.getReturnType() instanceof ImVoid, "func must return void");
// run some basic optimizations first:
func.flatten(tr);
new ConstantAndCopyPropagation().optimizeFunc(func);
new TempMerger().optimizeFunc(func);
new LocalMerger().optimizeFunc(func);
Set<ImVar> usedVars = UsedVariables.calculate(func);
func.getLocals().removeIf(v -> !usedVars.contains(v));
func.flatten(tr);
List<List<ImStmt>> splitResult = split(func.getBody().removeAll());

ImProg prog = tr.getImProg();
// make all local variables global
prog.getGlobals().addAll(func.getLocals().removeAll());

// create helper functions
List<ImFunction> helperFuncs = new ArrayList<>();
for (int i = 0; i < splitResult.size(); i++) {
List<ImStmt> stmts = splitResult.get(i);
ImFunction helperFunc = JassIm.ImFunction(
func.getTrace(),
func.getName() + "_" + i,
JassIm.ImTypeVars(),
JassIm.ImVars(),
JassIm.ImVoid(),
JassIm.ImVars(),
JassIm.ImStmts(stmts),
Collections.emptyList()
);
helperFuncs.add(helperFunc);
}
prog.getFunctions().addAll(helperFuncs);

// call helper functions with ExecuteFunc
for (ImFunction helperFunc : helperFuncs) {
func.getBody().add(JassIm.ImFunctionCall(
func.getTrace(),
helperFunc,
JassIm.ImTypeArguments(),
JassIm.ImExprs(),
false,
CallType.EXECUTE
));
}
}

private List<List<ImStmt>> split(List<ImStmt> body) {
List<List<ImStmt>> result = new ArrayList<>();
int fuel = 0;
for (ImStmt s : body) {
fuel += estimateFuel(s);
if (result.isEmpty() || fuel > op_limit) {
result.add(new ArrayList<>());
fuel = 0;
}
result.get(result.size() - 1).add(s);
}
return result;
}


private int estimateFuel(ImStmt s) {
return s.match(new ImStmt.Matcher<Integer>() {
@Override
public Integer case_ImTypeVarDispatch(ImTypeVarDispatch s) {
return estimateFuel(s.getArguments()) + 100;
}

@Override
public Integer case_ImDealloc(ImDealloc s) {
return 10 + estimateFuel(s.getObj());
}

@Override
public Integer case_ImBoolVal(ImBoolVal s) {
return 1;
}

@Override
public Integer case_ImTypeIdOfClass(ImTypeIdOfClass s) {
return 1;
}

@Override
public Integer case_ImVarAccess(ImVarAccess s) {
return 1;
}

@Override
public Integer case_ImStringVal(ImStringVal s) {
return 1;
}

@Override
public Integer case_ImMethodCall(ImMethodCall s) {
return estimateFuel(s.getArguments())
+ estimateFuelMethod(s.getMethod())
+ 10;
}

@Override
public Integer case_ImRealVal(ImRealVal s) {
return 1;
}

@Override
public Integer case_ImFunctionCall(ImFunctionCall s) {
return estimateFuel(s.getArguments()) + 10 + estimateFuelFunc(s.getFunc());
}

@Override
public Integer case_ImReturn(ImReturn s) {
return 3 + estimateFuelOpt(s.getReturnValue());
}

@Override
public Integer case_ImTupleSelection(ImTupleSelection s) {
return estimateFuel(s.getTupleExpr()) + 10;
}

@Override
public Integer case_ImOperatorCall(ImOperatorCall s) {
return 5 + estimateFuel(s.getArguments());
}

@Override
public Integer case_ImVarArrayAccess(ImVarArrayAccess s) {
return 3 + estimateFuel(s.getIndexes());
}

@Override
public Integer case_ImAlloc(ImAlloc s) {
return 30;
}

@Override
public Integer case_ImIntVal(ImIntVal s) {
return 1;
}

@Override
public Integer case_ImExitwhen(ImExitwhen s) {
return 5 + estimateFuel(s.getCondition());
}

@Override
public Integer case_ImVarargLoop(ImVarargLoop s) {
throw new CompileError(s, "Cannot estimate size of function " + func.getBody() + " as it contains looops.");
}

@Override
public Integer case_ImNull(ImNull s) {
return 1;
}

@Override
public Integer case_ImLoop(ImLoop s) {
throw new CompileError(s, "Cannot estimate size of function " + func.getBody() + " as it contains looops.");
}

@Override
public Integer case_ImMemberAccess(ImMemberAccess s) {
return 1 + estimateFuel(s.getReceiver())
+ estimateFuel(s.getIndexes());
}

@Override
public Integer case_ImGetStackTrace(ImGetStackTrace s) {
return 50;
}

@Override
public Integer case_ImTupleExpr(ImTupleExpr s) {
return 1 + estimateFuel(s.getExprs());
}

@Override
public Integer case_ImTypeIdOfObj(ImTypeIdOfObj s) {
return 1 + estimateFuel(s.getObj());
}

@Override
public Integer case_ImSet(ImSet s) {
return 3 + estimateFuel(s.getLeft()) + estimateFuel(s.getRight());
}

@Override
public Integer case_ImStatementExpr(ImStatementExpr s) {
return 1 + estimateFuel(s.getStatements()) + estimateFuel(s.getExpr());
}

@Override
public Integer case_ImCompiletimeExpr(ImCompiletimeExpr s) {
return 1;
}

@Override
public Integer case_ImIf(ImIf s) {
return 1 + estimateFuel(s.getCondition()) + Math.max(estimateFuel(s.getThenBlock()), estimateFuel(s.getElseBlock()));
}

@Override
public Integer case_ImCast(ImCast s) {
return estimateFuel(s.getExpr());
}

@Override
public Integer case_ImFuncRef(ImFuncRef s) {
return 1;
}

@Override
public Integer case_ImInstanceof(ImInstanceof s) {
return 1 + estimateFuel(s.getObj()) + 10 * tr.getImProg().getClasses().size();
}
});
}

private int estimateFuelMethod(ImMethod method) {
return Math.max(
estimateFuelFunc(method.getImplementation()),
method.getSubMethods().stream()
.mapToInt(m -> estimateFuelMethod(method))
.sum());
}

private int estimateFuelFunc(ImFunction f) {
if (f.isNative()) {
return 10;
}
if (fuelVisited.containsKey(f)) {
Integer v = fuelVisited.get(f);
if (v == null) {
throw new CompileError(func, "Cannot split recursive method " + func.getName() + " calling funcs: " +
fuelVisited.entrySet().stream()
.filter(e -> e.getValue() == null)
.map(e -> e.getKey().getName())
.collect(Collectors.joining(", ")));
}
return v;
} else {
// mark f as being calculated
fuelVisited.put(f, null);
int v = estimateFuel(f.getBody());
fuelVisited.put(f, v);
return v;
}
}

private int estimateFuelOpt(ImExprOpt returnValue) {
if (returnValue instanceof ImExpr) {
return estimateFuel((ImExpr) returnValue);
}
return 0;
}

private int estimateFuel(List<? extends ImStmt> stmts) {
return stmts.stream().mapToInt(this::estimateFuel).sum();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public String getName() {
return "Local variables merged";
}

private void optimizeFunc(ImFunction func) {
void optimizeFunc(ImFunction func) {
Map<ImStmt, Set<ImVar>> livenessInfo = calculateLiveness(func);
eliminateDeadCode(livenessInfo);
mergeLocals(livenessInfo, func);
Expand Down
Loading

0 comments on commit c8b3a7e

Please sign in to comment.