Skip to content

Commit

Permalink
Support for persisting objects in compiletime expressions (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
peq committed Mar 21, 2019
2 parents 54ceb4e + adeb192 commit eb06e6a
Show file tree
Hide file tree
Showing 12 changed files with 403 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
import de.peeeq.wurstscript.jassinterpreter.TestFailException;
import de.peeeq.wurstscript.jassinterpreter.TestSuccessException;
import de.peeeq.wurstscript.parser.WPos;
import de.peeeq.wurstscript.translation.imtranslation.CallType;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlag;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagCompiletime;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagEnum;
import de.peeeq.wurstscript.translation.imtranslation.*;
import de.peeeq.wurstscript.types.TypesHelper;
import de.peeeq.wurstscript.utils.Pair;
import de.peeeq.wurstscript.utils.Utils;
import org.eclipse.lsp4j.jsonrpc.messages.Either;
Expand All @@ -47,8 +45,9 @@ public class CompiletimeFunctionRunner {
private final List<ImFunction> successTests = Lists.newArrayList();
private final Map<ImFunction, Pair<de.peeeq.wurstscript.jassIm.Element, String>> failTests = Maps.newLinkedHashMap();
private final ProgramStateIO globalState;
private final ImTranslator translator;
private boolean injectObjects;
private final List<Runnable> delayedActions = new ArrayList<>();
private final Deque<Runnable> delayedActions = new ArrayDeque<>();

public ILInterpreter getInterpreter() {
return interpreter;
Expand Down Expand Up @@ -77,8 +76,9 @@ public boolean matches(ImFunction f) {
}


public CompiletimeFunctionRunner(ImProg imProg, File mapFile, MpqEditor mpqEditor, WurstGui gui, FunctionFlagToRun flag) {
public CompiletimeFunctionRunner(ImTranslator tr, ImProg imProg, File mapFile, MpqEditor mpqEditor, WurstGui gui, FunctionFlagToRun flag) {
Preconditions.checkNotNull(imProg);
this.translator = tr;
this.imProg = imProg;
globalState = new ProgramStateIO(mapFile, mpqEditor, gui, imProg, true);
this.interpreter = new ILInterpreter(imProg, gui, mapFile, globalState);
Expand Down Expand Up @@ -152,8 +152,8 @@ private void sendErrors(Element origin, String msg) {
* Run actions that must be run after all other code
*/
private void runDelayedActions() {
for (Runnable delayedAction : delayedActions) {
delayedAction.run();
while (!delayedActions.isEmpty()) {
delayedActions.removeFirst().run();
}
}

Expand Down Expand Up @@ -214,7 +214,7 @@ private void executeCompiletimeExpr(ImCompiletimeExpr cte) {
globalState.pushStackframe(cte, cte.attrTrace().attrErrorPos());
LocalState localState = new LocalState();
ILconst value = cte.evaluate(globalState, localState);
ImExpr newExpr = constantToExpr(cte, value);
ImExpr newExpr = constantToExpr(cte.getTrace(), value);
cte.replaceBy(newExpr);
} catch (InterpreterException e) {
String msg = ILInterpreter.buildStacktrace(globalState, e);
Expand All @@ -224,8 +224,66 @@ private void executeCompiletimeExpr(ImCompiletimeExpr cte) {
}
}

private ImExpr constantToExpr(ImCompiletimeExpr cte, ILconst value) {
Element trace = cte.attrTrace();

private GetAForB<ILconstObject, ImVar> globalForObject = new GetAForB<ILconstObject, ImVar>() {
@Override
public ImVar initFor(ILconstObject obj) {


ImVar res = JassIm.ImVar(obj.getTrace(), obj.getType(), obj.getType() + "_compiletime", false);
imProg.getGlobals().add(res);
ImAlloc alloc = JassIm.ImAlloc(obj.getTrace(), obj.getType());
addCompiletimeStateInitAlloc(alloc.getTrace(), res, alloc);


Element trace = obj.getTrace();

delayedActions.add(() -> {
for (Map.Entry<ImVar, Map<List<Integer>, ILconst>> entry : obj.getAttributes().rowMap().entrySet()) {
ImVar var = entry.getKey();
Map<List<Integer>, ILconst> value1 = entry.getValue();
for (Map.Entry<List<Integer>, ILconst> entry2 : value1.entrySet()) {
List<Integer> indexes = entry2.getKey();
ILconst attrValue = entry2.getValue();
ImExprs indexesT = indexes.stream()
.map(i -> constantToExpr(trace, ILconstInt.create(i)))
.collect(Collectors.toCollection(JassIm::ImExprs));
addCompiletimeStateInit(JassIm.ImSet(trace, JassIm.ImMemberAccess(trace, JassIm.ImVarAccess(res), JassIm.ImTypeArguments(), var, indexesT), constantToExpr(trace, attrValue)));
}
}
});

return res;
}
};

private GetAForB<IlConstHandle, ImVar> globalForHandle = new GetAForB<IlConstHandle, ImVar>() {
@Override
public ImVar initFor(IlConstHandle a) {

Element trace = imProg.getTrace();

ImExpr init;

Object obj = a.getObj();
if (obj instanceof ArrayListMultimap) {
@SuppressWarnings("unchecked")
ArrayListMultimap<HashtableProvider.KeyPair, Object> map = (ArrayListMultimap<HashtableProvider.KeyPair, Object>) obj;
ImType type = TypesHelper.imHashTable();
ImVar res = JassIm.ImVar(trace, type, type + "_compiletime", false);
imProg.getGlobals().add(res);

init = constantToExprHashtable(trace, res, a, map);
addCompiletimeStateInitAlloc(trace, res, init);

return res;
} else {
throw new RuntimeException("Handle value " + obj + " (" + obj.getClass() + ") can not be persistet at compiletime");
}
}
};

private ImExpr constantToExpr(Element trace, ILconst value) {
if (value instanceof ILconstBool) {
return JassIm.ImBoolVal(((ILconstBool) value).getVal());
} else if (value instanceof ILconstInt) {
Expand All @@ -236,40 +294,70 @@ private ImExpr constantToExpr(ImCompiletimeExpr cte, ILconst value) {
return JassIm.ImTupleExpr(
JassIm.ImExprs(
((ILconstTuple) value).values().stream()
.map(e -> constantToExpr(cte, e))
.map(e -> constantToExpr(trace, e))
.collect(Collectors.toList())
)
);
} else if (value instanceof IlConstHandle) {
IlConstHandle h = (IlConstHandle) value;
Object obj = h.getObj();
if (obj instanceof ArrayListMultimap) {
// a hashtable
@SuppressWarnings("unchecked")
ArrayListMultimap<HashtableProvider.KeyPair, Object> map = (ArrayListMultimap<HashtableProvider.KeyPair, Object>) obj;
return constantToExprHashtable(cte, trace, map);
}
ImVar hVar = globalForHandle.getFor(h);
return JassIm.ImVarAccess(hVar);
} else if (value instanceof ILconstObject) {
ILconstObject obj = (ILconstObject) value;
ImVar v = globalForObject.getFor(obj);
return JassIm.ImVarAccess(v);
}
throw new InterpreterException(trace, "Compiletime expression returned unsupported value " + value);

}

private ImFunction compiletimeStateInitFunction = null;

private ImFunction getCompiletimeStateInitFunction() {
ImFunction res = this.compiletimeStateInitFunction;
if (res == null) {
Element trace = imProg.getTrace();
res = JassIm.ImFunction(trace, "initCompiletimeState", JassIm.ImTypeVars(), JassIm.ImVars(), JassIm.ImVoid(), JassIm.ImVars(), JassIm.ImStmts(), Collections.emptyList());
imProg.getFunctions().add(res);
compiletimeStateInitFunction = res;
ImFunction mainFunc = translator.getMainFunc();
ImFunction globalInitFunc = translator.getGlobalInitFunc();
Preconditions.checkNotNull(mainFunc);
ListIterator<ImStmt> iterator = mainFunc.getBody().listIterator();
ImFunctionCall call = JassIm.ImFunctionCall(trace, res, JassIm.ImTypeArguments(), JassIm.ImExprs(), true, CallType.NORMAL);
while (iterator.hasNext()) {
ImStmt stmt = iterator.next();
if (stmt instanceof ImFunctionCall) {
ImFunctionCall fc = (ImFunctionCall) stmt;
if (fc.getFunc() == globalInitFunc) {
// call initCompiletimeState right after globalInitFunc
iterator.add(call);
return res;
}
}
}
iterator.add(call);
}
return res;
}

// insert at the beginning
private void addCompiletimeStateInitAlloc(Element trace, ImVar v, ImExpr init) {
imProg.getGlobalInits().put(v, Collections.singletonList(init));
getCompiletimeStateInitFunction().getBody().add(0, JassIm.ImSet(trace, JassIm.ImVarAccess(v), init.copy()));
}

// insert at the end
private void addCompiletimeStateInit(ImStmt stmt) {
getCompiletimeStateInitFunction().getBody().add(stmt);
}

/**
* Stores a hashtable value in a compiletime expression
* by generating the respective native calls
*/
private ImExpr constantToExprHashtable(ImCompiletimeExpr cte, Element trace, ArrayListMultimap<HashtableProvider.KeyPair, Object> map) {
ImFunction f = cte.getNearestFunc();
ImVar htVar = JassIm.ImVar(trace, cte.attrTyp(), "ht", false);
f.getLocals().add(htVar);


private ImExpr constantToExprHashtable(Element trace, ImVar htVar, IlConstHandle handle, ArrayListMultimap<HashtableProvider.KeyPair, Object> map) {
WPos errorPos = trace.attrErrorPos();
ImFunction initHashtable = findNative("InitHashtable", errorPos);
ImStmts stmts = JassIm.ImStmts(
JassIm.ImSet(trace, JassIm.ImVarAccess(htVar), JassIm.ImFunctionCall(trace, initHashtable, JassIm.ImTypeArguments(), JassIm.ImExprs(), false, CallType.NORMAL))
);

// we have to collect all values after all compiletime functions have run, so use delayedActions
delayedActions.add(() -> {
for (Map.Entry<HashtableProvider.KeyPair, Object> entry : map.entries()) {
Expand All @@ -278,7 +366,7 @@ private ImExpr constantToExprHashtable(ImCompiletimeExpr cte, Element trace, Arr
if (v instanceof ILconstInt) {
ILconstInt iv = (ILconstInt) v;
ImFunction SaveInteger = findNative("SaveInteger", errorPos);
stmts.add(JassIm.ImFunctionCall(trace, SaveInteger, JassIm.ImTypeArguments(), JassIm.ImExprs(
addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveInteger, JassIm.ImTypeArguments(), JassIm.ImExprs(
JassIm.ImVarAccess(htVar),
JassIm.ImIntVal(key.getParentkey()),
JassIm.ImIntVal(key.getChildkey()),
Expand All @@ -287,7 +375,7 @@ private ImExpr constantToExprHashtable(ImCompiletimeExpr cte, Element trace, Arr
} else if (v instanceof ILconstReal) {
ILconstReal iv = (ILconstReal) v;
ImFunction SaveReal = findNative("SaveReal", errorPos);
stmts.add(JassIm.ImFunctionCall(trace, SaveReal, JassIm.ImTypeArguments(), JassIm.ImExprs(
addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveReal, JassIm.ImTypeArguments(), JassIm.ImExprs(
JassIm.ImVarAccess(htVar),
JassIm.ImIntVal(key.getParentkey()),
JassIm.ImIntVal(key.getChildkey()),
Expand All @@ -296,7 +384,7 @@ private ImExpr constantToExprHashtable(ImCompiletimeExpr cte, Element trace, Arr
} else if (v instanceof ILconstString) {
ILconstString iv = (ILconstString) v;
ImFunction SaveStr = findNative("SaveStr", errorPos);
stmts.add(JassIm.ImFunctionCall(trace, SaveStr, JassIm.ImTypeArguments(), JassIm.ImExprs(
addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveStr, JassIm.ImTypeArguments(), JassIm.ImExprs(
JassIm.ImVarAccess(htVar),
JassIm.ImIntVal(key.getParentkey()),
JassIm.ImIntVal(key.getChildkey()),
Expand All @@ -305,7 +393,7 @@ private ImExpr constantToExprHashtable(ImCompiletimeExpr cte, Element trace, Arr
} else if (v instanceof ILconstBool) {
ILconstBool iv = (ILconstBool) v;
ImFunction SaveBoolean = findNative("SaveBoolean", errorPos);
stmts.add(JassIm.ImFunctionCall(trace, SaveBoolean, JassIm.ImTypeArguments(), JassIm.ImExprs(
addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveBoolean, JassIm.ImTypeArguments(), JassIm.ImExprs(
JassIm.ImVarAccess(htVar),
JassIm.ImIntVal(key.getParentkey()),
JassIm.ImIntVal(key.getChildkey()),
Expand All @@ -314,16 +402,12 @@ private ImExpr constantToExprHashtable(ImCompiletimeExpr cte, Element trace, Arr
} else {
throw new CompileError(errorPos, "Unsupported value stored in HashMap: " + v + " // " + v.getClass().getSimpleName());
}


}
});

// we already return the expr and fill out stmts in delayedActions (see above)
return JassIm.ImStatementExpr(
stmts,
JassIm.ImVarAccess(htVar)
);
ImFunction initHashtable = findNative("InitHashtable", errorPos);
return JassIm.ImFunctionCall(trace, initHashtable, JassIm.ImTypeArguments(), JassIm.ImExprs(), false, CallType.NORMAL);
}

@NotNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public void runCompiletime() {
// compile & inject object-editor data
// TODO run optimizations later?
gui.sendProgress("Running compiletime functions");
CompiletimeFunctionRunner ctr = new CompiletimeFunctionRunner(getImProg(), getMapFile(), getMapfileMpqEditor(), gui,
CompiletimeFunctionRunner ctr = new CompiletimeFunctionRunner(imTranslator, getImProg(), getMapFile(), getMapfileMpqEditor(), gui,
CompiletimeFunctions);
ctr.setInjectObjects(runArgs.isInjectObjects());
ctr.setOutputStream(new PrintStream(System.err));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public int getTotalTests() {
public TestResult runTests(ImProg imProg, @Nullable FuncDef funcToTest, @Nullable CompilationUnit cu) {
WurstGui gui = new TestGui();

CompiletimeFunctionRunner cfr = new CompiletimeFunctionRunner(imProg, null, null, gui, CompiletimeFunctions);
CompiletimeFunctionRunner cfr = new CompiletimeFunctionRunner(null, imProg, null, null, gui, CompiletimeFunctions);
ILInterpreter interpreter = cfr.getInterpreter();
ProgramState globalState = cfr.getGlobalState();
if (globalState == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package de.peeeq.wurstscript.intermediatelang;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import de.peeeq.wurstscript.ast.Element;
import de.peeeq.wurstscript.jassIm.ImClass;
import de.peeeq.wurstscript.jassIm.ImClassType;
import de.peeeq.wurstscript.jassIm.ImVar;

import java.util.*;
import java.util.stream.Collectors;

public class ILconstObject extends ILconstAbstract {
private final ImClassType classType;
private final int objectId;
private final Table<ImVar, List<Integer>, ILconst> attributes = HashBasedTable.create();
private boolean destroyed = false;
private final Element trace;

public ILconstObject(ImClassType classType, int objectId, Element trace) {
this.classType = classType;
this.objectId = objectId;
this.trace = trace;
}

public int getObjectId() {
return objectId;
}

@Override
public String print() {
return classType + "_" + hashCode();
}



@Override
public boolean isEqualTo(ILconst other) {
return other == this;
}

public void set(ImVar attr, List<Integer> indexes, ILconst value) {
attributes.put(attr, indexes, value);
}

public Optional<ILconst> get(ImVar attr, List<Integer> indexes) {
return Optional.ofNullable(attributes.get(attr, indexes));
}


public boolean isDestroyed() {
return destroyed;
}

public void destroy() {
destroyed = true;
}

public ImClass getImClass() {
return classType.getClassDef();
}

public Element getTrace() {
return trace;
}

public ImClassType getType() {
return classType;
}

@Override
public int hashCode() {
return objectId;
}

public Table<ImVar, List<Integer>, ILconst> getAttributes() {
return attributes;
}
}
Loading

0 comments on commit eb06e6a

Please sign in to comment.