Skip to content

Commit

Permalink
Fix normalization of devirtualized abstract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
titzer committed May 4, 2024
1 parent 2079f52 commit d91bf69
Show file tree
Hide file tree
Showing 19 changed files with 456 additions and 22 deletions.
2 changes: 1 addition & 1 deletion aeneas/src/ir/Normalization.v3
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ class ReachabilityNormalizer(config: NormalizerConfig, ra: ReachabilityAnalyzer)
private def getClassAllocIr(vn: VariantNorm) -> IrMethod {
var context = SsaContext.new(ra.compiler, ra.prog);
var meth = IrMethod.new(vn.oldType, null, Function.siga(vn.vecO, vn.newType));

var params = Array<SsaParam>.new(meth.sig.paramTypes.length + 1);
var inputs = Array<SsaInstr>.new(meth.sig.paramTypes.length);
params[0] = SsaParam.new(0, vn.oldType);
Expand Down
43 changes: 27 additions & 16 deletions aeneas/src/ir/SsaNormalizer.v3
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class SsaRaNormalizer extends SsaRebuilder {
}

def build(newMethod: IrMethod) {
// trace = CLOptions.TRACE_NORM.get().matches(newMethod.source);
// if (trace) Terminal.put1("normalizing %q\n", newMethod.renderLong);

newMethod.ssa = genGraph();
context.method = newMethod;
context.graph = newMethod.ssa;
Expand Down Expand Up @@ -156,6 +159,7 @@ class SsaRaNormalizer extends SsaRebuilder {
return vals;
}
def genApplyOp(i_old: SsaApplyOp) {
// if (trace) Terminal.put2("genApplyOp @%d %s\n", i_old.uid, i_old.op.opcode.name);
if (i_old.useList == null && i_old.facts.O_PURE) return; // remove dead code
curBlock.at(i_old.source);
var orig = i_old.op, op = i_old.op, args = i_old.inputs;
Expand Down Expand Up @@ -264,16 +268,22 @@ class SsaRaNormalizer extends SsaRebuilder {
// devirtualize methods that are not overridden
var t = extractVirtualRef(orig, method), funcNorm = t.0, m = t.1;
var ai_new = normArgs(funcNorm, genRefs(i_old.inputs));
var i_new: SsaInstr;
var i: SsaInstr;
if (t.2) { // still a virtual dispatch
i_new = normCall(i_old, funcNorm, V3Op.newCallClassSelector(m), ai_new);
i = normCall(i_old, funcNorm, V3Op.newCallClassSelector(m), ai_new);
} else {
// devirtualized to call abstract method => no objects instantiated of that type
if (m.member.facts.M_ABSTRACT) return map1(i_old, newGraph.nullConst(m.getReturnType()));
if (m.member.facts.M_ABSTRACT) {
var sig = funcNorm.sig(), rt = sig.returnTypes;
var rv = Vector<SsaInstr>.new();
for (i < rt.length) rv.put(newGraph.nullConst(rt[i]));
for (i < funcNorm.ovfReturnTypes.length) rv.put(newGraph.nullConst(funcNorm.ovfReturnTypes[i]));
return mapN(i_old, rv.extract());
}
var newOp = V3Op.newCallClassMethod(m);
i_new = normCall(i_old, funcNorm, newOp, ai_new);
i = normCall(i_old, funcNorm, newOp, ai_new);
}
if (ai_new[0].facts.V_NON_ZERO) i_new.setFact(Fact.O_NO_NULL_CHECK);
if (ai_new[0].facts.V_NON_ZERO) i.setFact(Fact.O_NO_NULL_CHECK);
}
CallVariantVirtual(method) => {
// devirtualize methods that are not overridden
Expand Down Expand Up @@ -463,7 +473,7 @@ class SsaRaNormalizer extends SsaRebuilder {
}
RefLayoutSetRepeatedField(offset, scale, max) => {
var rn = normTypeArg(op, 0), fn = normTypeArg(op, 1);
var ai_new = genRefs(args);
var ai_new = genRefs(args);
var array = ai_new[0], start = ai_new[1], index = ai_new[2], val = ai_new[3];
if (context.compiler.boundsCheck(i_old.facts)) {
var oob = curBlock.opBoolNot(curBlock.opIntULt(norm.config.ArrayLengthType, norm.config.ArrayLengthType,
Expand Down Expand Up @@ -654,7 +664,8 @@ class SsaRaNormalizer extends SsaRebuilder {
}
mapN(i_old, rvals.extract());
} else {
mapNorm(i_old, call, normType(i_old.op.sig.returnType()));
var tn = normType(i_old.op.sig.returnType());
mapNorm(i_old, call, tn);
}
return call;
}
Expand Down Expand Up @@ -706,7 +717,7 @@ class SsaRaNormalizer extends SsaRebuilder {
def normEqual(i_old: SsaApplyOp, tn: TypeNorm, refs: Array<SsaInstr>) -> SsaInstr {
if (tn.size == 0) return newGraph.trueConst();
if (tn.size == 1) return normEqual1(i_old, tn.newType, refs[0], refs[1]);

var expr: SsaInstr;
for (i < tn.size) {
var cmp = normEqual1(i_old, tn.sub[i], refs[i], refs[i + tn.size]);
Expand Down Expand Up @@ -855,7 +866,7 @@ class SsaRaNormalizer extends SsaRebuilder {
} else {
result.put(newGraph.zeroConst());
}

if (array.facts.V_NON_ZERO) {
len = curBlock.opArrayGetLength(arrayType, array);
} else {
Expand Down Expand Up @@ -1002,13 +1013,13 @@ class SsaRaNormalizer extends SsaRebuilder {
var result = Array<SsaInstr>.new(rangeNorm.size);
for (i < rangeNorm.startIndex()) result[i] = ai_new[i];
if (startType.width > 32) start = curBlock.opIntViewI0(startType, rangeLengthType, start);

if (norm.config.NormalizeRange && rangeNorm.arrayNorm.size <= 1) {
result[rangeNorm.startIndex()] = curBlock.opRangeStartPlusIndex(rangeNorm.oldType, rangeLengthType, i_old.facts, rangeStart, start);
} else {
result[rangeNorm.startIndex()] = curBlock.opIntAdd(start, rangeStart);
}

if (endType.width > 32) end = curBlock.opIntViewI0(endType, rangeLengthType, end);
result[rangeNorm.lengthIndex()] = curBlock.opIntSub(end, start);
mapN(i_old, result);
Expand Down Expand Up @@ -1036,10 +1047,10 @@ class SsaRaNormalizer extends SsaRebuilder {
if (startType.width > 32) start = curBlock.opIntViewI0(startType, rangeLengthType, start);
if (lengthType.width > 32) length = curBlock.opIntViewI0(lengthType, rangeLengthType, length);
}

var result = Array<SsaInstr>.new(rangeNorm.size);
for (i < rangeNorm.startIndex()) result[i] = ai_new[i];

if (norm.config.NormalizeRange && rangeNorm.arrayNorm.size <= 1) {
result[rangeNorm.startIndex()] = curBlock.opRangeStartPlusIndex(rangeNorm.oldType, rangeLengthType, i_old.facts, rangeStart, start);
} else {
Expand Down Expand Up @@ -1069,7 +1080,7 @@ class SsaRaNormalizer extends SsaRebuilder {
} else if (width == 0) {
return map0(i_old);
}

var vals = Array<SsaInstr>.new(width);
if (rangeNorm.arrayNorm.isMixed()) {
var array = ai_new[0];
Expand Down Expand Up @@ -1160,7 +1171,7 @@ class SsaRaNormalizer extends SsaRebuilder {
var oob = curBlock.pure(Int.getType(false, 32).opLtEq(), [length, index]);
curBlock.opConditionalThrow(V3Exception.BoundsCheck, oob);
}

var arrayNorm = rangeNorm.arrayNorm;
var facts = i_old.facts;

Expand Down Expand Up @@ -1384,7 +1395,7 @@ class SsaRaNormalizer extends SsaRebuilder {

var result = Array<SsaInstr>.new(vn.size);
for (i < result.length) result[i] = newGraph.nullConst(vn.at(i));

if (vn.hasExplicitTag()) {
var tagIdx = vn.tag.indexes[0];
if (IntRepType.?(vn.at(tagIdx))) result[tagIdx] = genSetInterval(result[tagIdx], newGraph.intConst(vn.tagValue), vn.tag.intervals[0], vn.tag.tn.newType, IntRepType.!(vn.at(tagIdx)));
Expand Down
2 changes: 2 additions & 0 deletions aeneas/src/main/CLOptions.v3
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ component CLOptions {
"Print the results of reachability analysis.");
def PRINT_SSA = debugOpt.newMatcherOption("print-ssa",
"Print internal SSA code as it is generated.");
def TRACE_NORM = debugOpt.newMatcherOption("trace-norm",
"Trace normalization of SSA for the given method(s).");
def PRINT_OPT = debugOpt.newMatcherOption("print-opt",
"Print optimizations as they are performed.");
def PRINT_SSA_STATS = debugOpt.newMatcherOption("print-ssa-stats",
Expand Down
2 changes: 1 addition & 1 deletion aeneas/src/main/Version.v3
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

// Updated by VCS scripts. DO NOT EDIT.
component Version {
def version: string = "III-7.1716";
def version: string = "III-7.1717";
var buildData: string;
}
14 changes: 10 additions & 4 deletions aeneas/src/ssa/SsaRebuilder.v3
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class SsaRebuilder(context: SsaContext) {
var curBlock: SsaBuilder;
var edgeMapMap: PartialMap<SsaBlock, Array<int>>;
var blocks: int;
// var trace: bool = false;

// Stack for processing phis and blocks. Order doesn't really matter as long as a block's
// dominators are processed before the block. Using a queue yields breadth-first order.
Expand Down Expand Up @@ -62,15 +63,15 @@ class SsaRebuilder(context: SsaContext) {
private def finishPhi(i_old: SsaPhi) {
// XXX: if only one predecessor, replace phi with its (one) input
var b_old = i_old.block, b_new = mapBlockStart(b_old);
// Terminal.put3("old #%d -> new #%d |%d|\n", b_old.uid, b_new.uid, b_new.preds.length);
// if (trace) Terminal.put3("finishPhi @%d (in block #%d, to block #%d)\n", i_old.uid, b_old.uid, b_new.uid);
var edgeMap = getEdgeMap(i_old, b_old, b_new);
if (instrMap.has1(i_old)) {
var i_new = SsaPhi.!(instrMap[i_old]);
// phi was mapped one-to-one; map the new inputs
var ai_new_inputs = Array<SsaInstr>.new(b_new.preds.length);
var facts = Facts.NONE;
for (j < ai_new_inputs.length) {
// Terminal.put3("@%d[%d] -> i_new[%d]\n", i_old.uid, edgeMap[j], j);
// if (trace) Terminal.put3("@%d[%d] -> i_new[%d]\n", i_old.uid, edgeMap[j], j);
var i_input = genRef1(i_old.inputs[edgeMap[j]]);
ai_new_inputs[j] = i_input;
if (j == 0) facts = i_input.facts;
Expand All @@ -85,11 +86,16 @@ class SsaRebuilder(context: SsaContext) {
var ai_new = instrMap.getN(i_old);
for (w < ai_new.length) {
var i_new_phi = SsaPhi.!(ai_new[w]);
// if (trace) Terminal.put3(" @%d[%d] -> @%d\n", i_old.uid, w, i_new_phi.uid);
var ai_new_inputs = Array<SsaInstr>.new(b_new.preds.length);
for (j < ai_new_inputs.length) {
// XXX: interchange these loops for better performance
var index = i_old.inputs[edgeMap[j]];
ai_new_inputs[j] = genRefs([index])[w];
var e = edgeMap[j];
// if (trace) Terminal.put2(" edgeMap[%d] = %d\n", j, e);
var old_input = i_old.inputs[e];
var new_inputs = genRefs([old_input]);
// if (trace) Terminal.put2(" old = @%d -> |%d|\n", old_input.dest.uid, new_inputs.length);
ai_new_inputs[j] = new_inputs[w];
}
i_new_phi.setInputs(ai_new_inputs);
b_new.prepend(i_new_phi);
Expand Down
19 changes: 19 additions & 0 deletions test/core/emap05.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@execute 0=!NullCheckException
var f: Y;
def main(a: int) -> int {
return f.m(a).0;
}
def foo() {
if (f == null) f = Z.new();
}

class X {
def m(a: int) -> (int, int);
}
class Y extends X {
}
class Z extends Y {
def m(a: int) -> (int, int) {
return (a + 42, a + 53);
}
}
21 changes: 21 additions & 0 deletions test/core/emap06.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//@execute -1=0; 0=!NullCheckException
var f: B;
def main(a: int) -> int {
var x: int;
if (a >= 0) x = f.m(a).0;
return x;
}
def foo() {
if (f == null) f = C.new();
}

class A {
def m(a: int) -> (int, int);
}
class B extends A {
}
class C extends B {
def m(a: int) -> (int, int) {
return (a + 42, a + 53);
}
}
21 changes: 21 additions & 0 deletions test/core/emap07.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//@execute -1=0; 0=!NullCheckException
var f: B;
def main(a: int) -> int {
var x: (int, int);
if (a >= 0) x = f.m(a);
return x.0;
}
def foo() {
if (f == null) f = C.new();
}

class A {
def m(a: int) -> (int, int);
}
class B extends A {
}
class C extends B {
def m(a: int) -> (int, int) {
return (a + 42, a + 53);
}
}
18 changes: 18 additions & 0 deletions test/core/emap08.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//@execute -1=0; 0=!NullCheckException
var f: B;
def main(a: int) -> int {
var x: (int, int);
if (a >= 0) x = f.m(a);
return x.0;
}

class A {
def m(a: int) -> (int, int);
}
class B extends A {
}
class C extends B {
def m(a: int) -> (int, int) {
return (a + 42, a + 53);
}
}
16 changes: 16 additions & 0 deletions test/core/emap09.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//@execute -1=8; 0=!NullCheckException
var f: B;
def main(a: int) -> int {
var x = int.!<int>;
if (a >= 0) x = f.m(a);
return x(8);
}

class A {
def m(a: int) -> (int -> int);
}
class B extends A {
}
class C extends B {
def m(a: int) -> (int -> int) { return int.~; }
}
72 changes: 72 additions & 0 deletions test/variants/ub_frameloc00.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//@execute 0=11; 1=11

class WasmFunction(id: int) { }
class HostFunction(id: int) { }
class HostFrame {
def caller() -> FrameLoc {
return FrameLoc.None;
}
}
class V3Frame {
var prev: V3Frame;
var func: WasmFunction;
var pc: int;
}
class FrameAccessor {
def caller() -> FrameLoc {
return FrameLoc.None;
}
}

type TargetFrame(f: V3Frame) #unboxed {
def getFrameAccessor() -> FrameAccessor {
return FrameAccessor.new();
}
}

type FrameLoc #unboxed {
case None;
case Wasm(func: WasmFunction, pc: int, frame: TargetFrame);
case Host(func: HostFunction, frame: HostFrame);
}

def FUNC_19 = WasmFunction.new(19);
def FUNC_21 = WasmFunction.new(21);
def HOST_22 = HostFunction.new(22);

def frame2 = V3Frame.new();

def frames = [
FrameLoc.Wasm(FUNC_19, 11, TargetFrame(V3Frame.new())),
FrameLoc.Wasm(FUNC_21, 13, TargetFrame(V3Frame.new())),
FrameLoc.Host(HOST_22, HostFrame.new())
];

def main(a: int) -> int {
var g = FrameLoc.Wasm.!(frames[a]);
renderStack(g.frame.getFrameAccessor());
return 11;
}

def renderStack(accessor: FrameAccessor) {
var depth = 0;
var caller = accessor.caller();
while (true) {
match (caller) {
None => break;
Wasm(func, pc, frame) => {
// OUT.put1("%d: ", depth);
// func.render(OUT);
accessor = frame.getFrameAccessor();
caller = accessor.caller();
}
Host(func, frame) => {
// OUT.put1("%d: ", depth);
// func.render(OUT);
caller = frame.caller();
}
}
// OUT.outln();
depth++;
}
}

0 comments on commit d91bf69

Please sign in to comment.