Skip to content

Commit

Permalink
add overflow checks for add.ovf on ints/uints (dotnet#8089)
Browse files Browse the repository at this point in the history
  • Loading branch information
yowl committed May 1, 2020
1 parent 20fad5b commit e4131da
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 9 deletions.
101 changes: 93 additions & 8 deletions src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs
Expand Up @@ -735,7 +735,7 @@ private void StartImportingInstruction()
debugMetadata = new DebugMetadata(fileMetadata, compileUnitMetadata);
_compilation.DebugMetadataMap[fileName] = debugMetadata;
}

if (_debugFunction.Handle == IntPtr.Zero)
{
LLVMMetadataRef functionMetaType = _compilation.DIBuilder.CreateSubroutineType(debugMetadata.File, ReadOnlySpan<LLVMMetadataRef>.Empty, LLVMDIFlags.LLVMDIFlagZero);
Expand Down Expand Up @@ -2205,6 +2205,11 @@ private static LLVMValueRef BuildConstInt64(long number)
return LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, (ulong)number, false);
}

private static LLVMValueRef BuildConstUInt64(ulong number)
{
return LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, number, false);
}

private LLVMValueRef GetEETypeForTypeDesc(TypeDesc target, bool constructed)
{
var eeTypePointer = GetEETypePointerForTypeDesc(target, constructed);
Expand Down Expand Up @@ -2392,7 +2397,7 @@ private LLVMBasicBlockRef HandleCall(MethodDesc callee, MethodSignature signatur
var canonConstrainedType = constrainedType;
if (constrainedType.IsRuntimeDeterminedSubtype)
canonConstrainedType = constrainedType.ConvertToCanonForm(CanonicalFormKind.Specific);

bool forceUseRuntimeLookup;
var constrainedClosestDefType = canonConstrainedType.GetClosestDefType();
MethodDesc directMethod = constrainedClosestDefType.TryResolveConstraintMethodApprox(callee.OwningType, callee, out forceUseRuntimeLookup);
Expand Down Expand Up @@ -2605,7 +2610,7 @@ private LLVMBasicBlockRef HandleCall(MethodDesc callee, MethodSignature signatur
var fatBranch = _currentFunclet.AppendBasicBlock("fat");
var endifBlock = _currentFunclet.AppendBasicBlock("endif");
builder.BuildCondBr(eqZ, notFatBranch, fatBranch);

// then
builder.PositionAtEnd(notFatBranch);
ExceptionRegion currentTryRegion = GetCurrentTryRegion();
Expand Down Expand Up @@ -3686,11 +3691,33 @@ private void ImportBinaryOperation(ILOpcode opcode)
result = _builder.BuildXor(left, right, "xor");
break;

// TODO: Overflow checks
case ILOpcode.add_ovf:
Debug.Assert(type.Category == TypeFlags.Int32 || type.Category == TypeFlags.Int64);
if (type.Category == TypeFlags.Int32)
{
BuildAddOverflowChecksForSize(ref Ovf32Function, left, right, LLVMTypeRef.Int32, BuildConstInt32(int.MaxValue), BuildConstInt32(int.MinValue), true);
}
else
{
BuildAddOverflowChecksForSize(ref Ovf64Function, left, right, LLVMTypeRef.Int64, BuildConstInt64(long.MaxValue), BuildConstInt64(long.MinValue), true);
}

result = _builder.BuildAdd(left, right, "add");
break;
case ILOpcode.add_ovf_un:
Debug.Assert(type.Category == TypeFlags.UInt32 || type.Category == TypeFlags.Int32 || type.Category == TypeFlags.UInt64 || type.Category == TypeFlags.Int64 || type.Category == TypeFlags.Pointer);
if (type.Category == TypeFlags.UInt32 || type.Category == TypeFlags.Int32 || type.Category == TypeFlags.Pointer)
{
BuildAddOverflowChecksForSize(ref OvfUn32Function, left, right, LLVMTypeRef.Int32, BuildConstUInt32(uint.MaxValue), BuildConstInt32(int.MinValue), false);
}
else
{
BuildAddOverflowChecksForSize(ref OvfUn64Function, left, right, LLVMTypeRef.Int64, BuildConstUInt64(ulong.MaxValue), BuildConstInt64(long.MinValue), false);
}

result = _builder.BuildAdd(left, right, "add");
break;
// TODO: Overflow checks
case ILOpcode.sub_ovf:
case ILOpcode.sub_ovf_un:
result = _builder.BuildSub(left, right, "sub");
Expand All @@ -3714,6 +3741,63 @@ private void ImportBinaryOperation(ILOpcode opcode)
PushExpression(kind, "binop", result, type);
}

void BuildAddOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValueRef left, LLVMValueRef right, LLVMTypeRef sizeTypeRef, LLVMValueRef maxValue, LLVMValueRef minValue, bool signed)
{
if (llvmCheckFunction.Handle == IntPtr.Zero)
{
// create function name for each of the 4 combinations signed/unsigned, 32/64 bit
string throwFuncName = "corert.throwovf" + (signed ? "" : "un") + (sizeTypeRef.IntWidth == 32 ? "32" : "64");
llvmCheckFunction = Module.AddFunction(throwFuncName, LLVMTypeRef.CreateFunction(LLVMTypeRef.Void, new LLVMTypeRef[] { LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), sizeTypeRef, sizeTypeRef }, false));
var leftOp = llvmCheckFunction.GetParam(1);
var rightOp = llvmCheckFunction.GetParam(2);
var builder = Context.CreateBuilder();
var block = llvmCheckFunction.AppendBasicBlock("Block");
builder.PositionAtEnd(block);
LLVMBasicBlockRef elseBlock = default;
if (signed) // signed ops need a separate test for the left side being negative
{
var gtZeroCmp = builder.BuildICmp(LLVMIntPredicate.LLVMIntSGT, leftOp,
LLVMValueRef.CreateConstInt(sizeTypeRef, 0, false));
LLVMBasicBlockRef thenBlock = llvmCheckFunction.AppendBasicBlock("posOvfBlock");
elseBlock = llvmCheckFunction.AppendBasicBlock("negOvfBlock");
builder.BuildCondBr(gtZeroCmp, thenBlock, elseBlock);
builder.PositionAtEnd(thenBlock);
}
LLVMBasicBlockRef ovfBlock = llvmCheckFunction.AppendBasicBlock("ovfBlock");
LLVMBasicBlockRef noOvfBlock = llvmCheckFunction.AppendBasicBlock("noOvfBlock");
// b > int.MaxValue - a
BuildOverflowCheck(builder, leftOp, rightOp, maxValue,
signed ? LLVMIntPredicate.LLVMIntSGT : LLVMIntPredicate.LLVMIntUGT, ovfBlock, noOvfBlock);

builder.PositionAtEnd(ovfBlock);

ThrowException(builder, "ThrowHelpers", "ThrowOverflowException", llvmCheckFunction);

builder.PositionAtEnd(noOvfBlock);
LLVMBasicBlockRef opBlock = llvmCheckFunction.AppendBasicBlock("opBlock");
builder.BuildBr(opBlock);

if (signed)
{
builder.PositionAtEnd(elseBlock);
// b < int.MinValue - a
BuildOverflowCheck(builder, leftOp, rightOp, minValue, LLVMIntPredicate.LLVMIntSLT, ovfBlock, noOvfBlock);
}
builder.PositionAtEnd(opBlock);
builder.BuildRetVoid();
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), left, right }, ref nextInstrBlock);
}

private void BuildOverflowCheck(LLVMBuilderRef builder, LLVMValueRef left, LLVMValueRef right, LLVMValueRef limitValueRef, LLVMIntPredicate predicate, LLVMBasicBlockRef ovfBlock, LLVMBasicBlockRef noOvfBlock)
{
LLVMValueRef sub = builder.BuildSub(limitValueRef, left);
LLVMValueRef ovfTest = builder.BuildICmp(predicate, right, sub);
builder.BuildCondBr(ovfTest, ovfBlock, noOvfBlock);
}

private TypeDesc WidenBytesAndShorts(TypeDesc type)
{
switch (type.Category)
Expand Down Expand Up @@ -3866,6 +3950,7 @@ private void ImportCompareOperation(ILOpcode opcode)

private void ImportConvert(WellKnownType wellKnownType, bool checkOverflow, bool unsigned)
{
//TODO checkOverflow
StackEntry value = _stack.Pop();
TypeDesc destType = GetWellKnownType(wellKnownType);

Expand Down Expand Up @@ -4155,7 +4240,7 @@ private void ImportThrow()

private LLVMBasicBlockRef GetOrCreateUnreachableBlock()
{
if(_funcletUnreachableBlocks.TryGetValue(_currentFunclet.Handle, out LLVMBasicBlockRef unreachableBlock))
if (_funcletUnreachableBlocks.TryGetValue(_currentFunclet.Handle, out LLVMBasicBlockRef unreachableBlock))
{
return unreachableBlock;
}
Expand Down Expand Up @@ -4413,7 +4498,7 @@ ISymbolNode GetGenericLookupHelperAndAddReference(ReadyToRunHelperId helperId, o
}
else
{
Debug.Assert(_method.RequiresInstMethodTableArg() || _method.AcquiresInstMethodTableFromThis());
Debug.Assert(_method.RequiresInstMethodTableArg() || _method.AcquiresInstMethodTableFromThis());
node = _compilation.NodeFactory.ReadyToRunHelperFromTypeLookup(helperId, helperArg, _method.OwningType);
helper = GetOrCreateLLVMFunction(node.GetMangledName(_compilation.NameMangler),
LLVMTypeRef.CreateFunction(retType, helperArgs.ToArray(), false));
Expand Down Expand Up @@ -4686,7 +4771,7 @@ LLVMValueRef GetGenericContext()
{
LLVMValueRef typedAddress;
LLVMValueRef thisPtr;

typedAddress = CastIfNecessary(_builder, _currentFunclet.GetParam(0),
LLVMTypeRef.CreatePointer(LLVMTypeRef.CreatePointer(LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), 0), 0));
thisPtr = _builder.BuildLoad( typedAddress, "loadThis");
Expand Down Expand Up @@ -4980,7 +5065,7 @@ public ExpressionEntry OutputCodeForGetThreadStaticBaseForType(LLVMValueRef thre
var threadStaticIndexPtr = _builder.BuildPointerCast(threadStaticIndex,
LLVMTypeRef.CreatePointer(LLVMTypeRef.CreatePointer(LLVMTypeRef.Int32, 0), 0), "tsiPtr");
LLVMValueRef typeTlsIndexPtr =
_builder.BuildGEP( threadStaticIndexPtr, new LLVMValueRef[] { BuildConstInt32(1) }, "typeTlsIndexPtr"); // index is the second field after the ptr.
_builder.BuildGEP(threadStaticIndexPtr, new LLVMValueRef[] { BuildConstInt32(1) }, "typeTlsIndexPtr"); // index is the second field after the ptr.

StackEntry typeManagerSlotEntry = new LoadExpressionEntry(StackValueKind.ValueType, "typeManagerSlot", threadStaticIndexPtr, GetWellKnownType(WellKnownType.Int32));
StackEntry tlsIndexExpressionEntry = new LoadExpressionEntry(StackValueKind.ValueType, "typeTlsIndex", typeTlsIndexPtr, GetWellKnownType(WellKnownType.Int32));
Expand Down
Expand Up @@ -119,6 +119,10 @@ public static void CompileMethod(WebAssemblyCodegenCompilation compilation, WebA
static LLVMValueRef NullRefFunction = default(LLVMValueRef);
static LLVMValueRef CkFinite32Function = default(LLVMValueRef);
static LLVMValueRef CkFinite64Function = default(LLVMValueRef);
static LLVMValueRef Ovf32Function = default(LLVMValueRef);
static LLVMValueRef OvfUn32Function = default(LLVMValueRef);
static LLVMValueRef Ovf64Function = default(LLVMValueRef);
static LLVMValueRef OvfUn64Function = default(LLVMValueRef);
public static LLVMValueRef GxxPersonality = default(LLVMValueRef);
public static LLVMTypeRef GxxPersonalityType = default(LLVMTypeRef);

Expand Down
160 changes: 159 additions & 1 deletion tests/src/Simple/HelloWasm/Program.cs
Expand Up @@ -351,6 +351,8 @@ private static unsafe int Main(string[] args)
TestCkFinite();
#endif

TestIntOverflows();

// This test should remain last to get other results before stopping the debugger
PrintLine("Debugger.Break() test: Ok if debugger is open and breaks.");
System.Diagnostics.Debugger.Break();
Expand Down Expand Up @@ -1758,7 +1760,6 @@ static void TestThrowIfNull()
{
success = false;
}

try
{
var f = c.ToString(); //method access
Expand Down Expand Up @@ -1813,6 +1814,163 @@ private static unsafe bool CkFinite64(ulong value)
}
#endif

static void TestIntOverflows()
{
TestSignedIntAddOvf();

TestSignedLongAddOvf();

TestUnsignedIntAddOvf();

TestUnsignedLongAddOvf();
}

private static void TestSignedLongAddOvf()
{
StartTest("Test long add overflows");
bool thrown;
long op64l = 1;
long op64r = long.MaxValue;
thrown = false;
try
{
long res = checked(op64l + op64r);
}
catch (OverflowException)
{
thrown = true;
}
if (!thrown)
{
FailTest("exception not thrown for signed i64 addition of +ve number");
return;
}
thrown = false;
op64l = long.MinValue; // add negative to overflow below the MinValue
op64r = -1;
try
{
long res = checked(op64l + op64r);
}
catch (OverflowException)
{
thrown = true;
}
if (!thrown)
{
FailTest("exception not thrown for signed i64 addition of -ve number");
return;
}
EndTest(true);
}

private static void TestSignedIntAddOvf()
{
StartTest("Test int add overflows");
bool thrown;
int op32l = 1;
int op32r = 2;
if (checked(op32l + op32r) != 3)
{
FailTest("No overflow failed"); // check not always throwing an exception
return;
}
op32l = 1;
op32r = int.MaxValue;
thrown = false;
try
{
int res = checked(op32l + op32r);
}
catch (OverflowException)
{
thrown = true;
}
if (!thrown)
{
FailTest("exception not thrown for signed i32 addition of +ve number");
return;
}

thrown = false;
op32l = int.MinValue; // add negative to overflow below the MinValue
op32r = -1;
try
{
int res = checked(op32l + op32r);
}
catch (OverflowException)
{
thrown = true;
}
if (!thrown)
{
FailTest("exception not thrown for signed i32 addition of -ve number");
return;
}
PassTest();
}

private static void TestUnsignedIntAddOvf()
{
StartTest("Test uint add overflows");
bool thrown;
uint op32l = 1;
uint op32r = 2;
if (checked(op32l + op32r) != 3)
{
FailTest("No overflow failed"); // check not always throwing an exception
return;
}
op32l = 1;
op32r = uint.MaxValue;
thrown = false;
try
{
uint res = checked(op32l + op32r);
}
catch (OverflowException)
{
thrown = true;
}
if (!thrown)
{
FailTest("exception not thrown for unsigned i32 addition of +ve number");
return;
}
PassTest();
}

private static void TestUnsignedLongAddOvf()
{
StartTest("Test ulong add overflows");
bool thrown;
ulong op64l = 1;
ulong op64r = 2;
if (checked(op64l + op64r) != 3)
{
FailTest("No overflow failed"); // check not always throwing an exception
return;
}
op64l = 1;
op64r = ulong.MaxValue;
thrown = false;
try
{
ulong res = checked(op64l + op64r);
}
catch (OverflowException)
{
thrown = true;
}
if (!thrown)
{
FailTest("exception not thrown for unsigned i64 addition of +ve number");
return;
}
PassTest();
}

static ushort ReadUInt16()
{
// something with MSB set
Expand Down

0 comments on commit e4131da

Please sign in to comment.