Skip to content

Commit

Permalink
transform (coroutines): fix memory corruption for tail calls that reference stack allocations
Browse files Browse the repository at this point in the history

This change fixes a bug in which the lifetime of `alloca` memory did not extend past the suspend point of an asynchronous tail call.
The bug typically manifested as memory corruption, and could occur whether or not the function also contained normal suspending calls.
  • Loading branch information
niaow authored and aykevl committed Sep 21, 2021
1 parent a116fd0 commit ecd8c2d
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 12 deletions.
38 changes: 33 additions & 5 deletions transform/coroutines.go
Expand Up @@ -600,11 +600,11 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
continue
}

if len(fn.normalCalls) == 0 {
// No suspend points. Lower without turning it into a coroutine.
if len(fn.normalCalls) == 0 && fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
// No suspend points or stack allocations. Lower without turning it into a coroutine.
c.lowerFuncFast(fn)
} else {
// There are suspend points, so it is necessary to turn this into a coroutine.
// There are suspend points or stack allocations, so it is necessary to turn this into a coroutine.
c.lowerFuncCoro(fn)
}
}
Expand Down Expand Up @@ -827,6 +827,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
}

// Lower returns.
var postTail llvm.BasicBlock
for _, ret := range fn.returns {
// Get terminator instruction.
terminator := ret.block.LastInstruction()
Expand Down Expand Up @@ -886,10 +887,37 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
call.EraseFromParentAsInstruction()
}

// Replace terminator with branch to cleanup.
// Replace terminator with a branch to the exit.
var exit llvm.BasicBlock
if ret.kind == returnNormal || ret.kind == returnVoid || fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
// Exit through the cleanup path.
exit = cleanup
} else {
if postTail.IsNil() {
// Create a path with a suspend that never reawakens.
postTail = c.ctx.AddBasicBlock(fn.fn, "post.tail")
c.builder.SetInsertPointAtEnd(postTail)
// %coro.save = call token @llvm.coro.save(i8* %coro.state)
save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save")
// %call.suspend = llvm.coro.suspend(token %coro.save, i1 false)
// switch i8 %call.suspend, label %suspend [i8 0, label %wakeup
// i8 1, label %cleanup]
suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend")
sw := c.builder.CreateSwitch(suspendValue, suspend, 2)
unreachableBlock := c.ctx.AddBasicBlock(fn.fn, "unreachable")
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), unreachableBlock)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup)
c.builder.SetInsertPointAtEnd(unreachableBlock)
c.builder.CreateUnreachable()
}

// Exit through a permanent suspend.
exit = postTail
}

terminator.EraseFromParentAsInstruction()
c.builder.SetInsertPointAtEnd(ret.block)
c.builder.CreateBr(cleanup)
c.builder.CreateBr(exit)
}

// Lower regular calls.
Expand Down
34 changes: 33 additions & 1 deletion transform/testdata/coroutines.ll
Expand Up @@ -86,11 +86,43 @@ entry:
}

; Normal function which should not be transformed.
define void @doNothing(i8*, i8*) {
define void @doNothing(i8*, i8* %parentHandle) {
entry:
ret void
}

; Regression test: ensure that a tail call does not destroy the frame while it is still in use.
; Previously, the tail-call lowering transform would branch to the cleanup block after usePtr.
; This caused the lifetime of %a to be incorrectly reduced, and allowed the coroutine lowering transform to keep %a on the stack.
; After a suspend %a would be used, resulting in memory corruption.
; NOTE(review): the only call in this function is the tail call to @usePtr. The description above
; (a branch to cleanup after usePtr) presupposes that the function was already lowered as a
; coroutine - confirm that this text and the one on @allocaTailRegression are attached to the
; intended functions.
define i8 @coroutineTailRegression(i8*, i8* %parentHandle) {
entry:
; %a is a stack allocation whose address is passed to the tail-called function below.
%a = alloca i8
store i8 5, i8* %a
; @usePtr loads through this pointer only after its own call to @sleep.
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
ret i8 %val
}

; Regression test: ensure that stack allocations alive during a suspend end up on the heap.
; This used to not be transformed to a coroutine, keeping %a on the stack.
; After a suspend %a would be used, resulting in memory corruption.
; NOTE(review): unlike @coroutineTailRegression, this function also makes a non-tail call to
; @sleep before the tail call to @usePtr - confirm that the "used to not be transformed to a
; coroutine" description is attached to the intended function.
define i8 @allocaTailRegression(i8*, i8* %parentHandle) {
entry:
%a = alloca i8
; Non-tail call: the store to %a and the call to @usePtr below happen after it returns.
call void @sleep(i64 1000000, i8* undef, i8* null)
store i8 5, i8* %a
; Tail call: the result of @usePtr is returned directly, and %a is passed by address.
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
ret i8 %val
}

; usePtr uses a pointer after a suspend.
; It calls @sleep first and only afterwards loads an i8 through its first argument (%0),
; so the memory behind %0 must still be valid once @sleep has completed.
define i8 @usePtr(i8*, i8*, i8* %parentHandle) {
entry:
call void @sleep(i64 1000000, i8* undef, i8* null)
; The load happens after the call to @sleep; the loaded value is the function result.
%val = load i8, i8* %0
ret i8 %val
}

; Goroutine that sleeps and does nothing.
; Should be a void tail call.
define void @sleepGoroutine(i8*, i8* %parentHandle) {
Expand Down
128 changes: 122 additions & 6 deletions transform/testdata/coroutines.out.ll
Expand Up @@ -45,7 +45,7 @@ entry:
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
store i32 %0, i32* %ret.ptr.bitcast
store i32 %0, i32* %ret.ptr.bitcast, align 4
call void @sleep(i64 %1, i8* undef, i8* %parentHandle)
ret i32 undef
}
Expand Down Expand Up @@ -84,7 +84,7 @@ entry:
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
store i32 %0, i32* %ret.ptr.bitcast
store i32 %0, i32* %ret.ptr.bitcast, align 4
%ret.alternate = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.alternate, i8* undef, i8* undef)
%4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* %parentHandle)
Expand All @@ -93,7 +93,7 @@ entry:

define i1 @coroutine(i32 %0, i64 %1, i8* %2, i8* %parentHandle) {
entry:
%call.return = alloca i32
%call.return = alloca i32, align 4
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
Expand All @@ -116,10 +116,10 @@ entry:
]

wakeup: ; preds = %entry
%4 = load i32, i32* %call.return
%4 = load i32, i32* %call.return, align 4
call void @llvm.lifetime.end.p0i8(i64 4, i8* %call.return.bitcast)
%5 = icmp eq i32 %4, 0
store i1 %5, i1* %task.retPtr.bitcast
store i1 %5, i1* %task.retPtr.bitcast, align 1
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current2, i8* %task.state.parent, i8* undef, i8* undef)
br label %cleanup

Expand All @@ -133,11 +133,127 @@ cleanup: ; preds = %entry, %wakeup
br label %suspend
}

define void @doNothing(i8* %0, i8* %1) {
define void @doNothing(i8* %0, i8* %parentHandle) {
entry:
ret void
}

; Expected output for @coroutineTailRegression: the function is lowered as a coroutine
; (llvm.coro.id / llvm.coro.begin) even though its only call is the tail call to @usePtr.
; After that call, control goes to %post.tail, which suspends, instead of branching directly
; to %cleanup.
define i8 @coroutineTailRegression(i8* %0, i8* %parentHandle) {
entry:
%a = alloca i8, align 1
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
; Install this coroutine's state on the task, keeping the previous state and return pointer.
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
store i8 5, i8* %a, align 1
; Before the tail call, put the saved state and return pointer back on the task.
%coro.state.restore = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
%val = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
; Exit through the suspend in %post.tail rather than straight through %cleanup.
br label %post.tail

suspend:                                          ; preds = %post.tail, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef

; Releases the memory reported by llvm.coro.free, then joins the common %suspend exit.
cleanup:                                          ; preds = %post.tail
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend

; Suspend after the tail call. Case 0 leads to %unreachable, case 1 to %cleanup, and the
; default to %suspend.
post.tail:                                        ; preds = %entry
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %unreachable
i8 1, label %cleanup
]

unreachable:                                      ; preds = %post.tail
unreachable
}

; Expected output for @allocaTailRegression: a coroutine with two suspends. The first follows
; the non-tail call to @sleep and continues in %wakeup. The second is in %post.tail, reached
; after the tail call to @usePtr, and its case 0 leads to %unreachable.
define i8 @allocaTailRegression(i8* %0, i8* %parentHandle) {
entry:
%a = alloca i8, align 1
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
; Install this coroutine's state on the task, keeping the previous state and return pointer.
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
; Suspend after the call to @sleep; case 0 continues in %wakeup.
%coro.save1 = call token @llvm.coro.save(i8* %coro.state)
%call.suspend2 = call i8 @llvm.coro.suspend(token %coro.save1, i1 false)
switch i8 %call.suspend2, label %suspend [
i8 0, label %wakeup
i8 1, label %cleanup
]

wakeup:                                           ; preds = %entry
; %a is written and its address passed on only after the first suspend.
store i8 5, i8* %a, align 1
; Before the tail call, put the saved state and return pointer back on the task.
%1 = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
%2 = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
; Exit through the suspend in %post.tail rather than straight through %cleanup.
br label %post.tail

suspend:                                          ; preds = %entry, %post.tail, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef

; Releases the memory reported by llvm.coro.free, then joins the common %suspend exit.
cleanup:                                          ; preds = %entry, %post.tail
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend

; Suspend after the tail call. Case 0 leads to %unreachable, case 1 to %cleanup, and the
; default to %suspend.
post.tail:                                        ; preds = %wakeup
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %unreachable
i8 1, label %cleanup
]

unreachable:                                      ; preds = %post.tail
unreachable
}

; Expected output for @usePtr: a coroutine that suspends after calling @sleep. In %wakeup it
; loads through %0, stores the loaded value through the task's return pointer, and calls
; returnTo with the saved parent state before exiting through %cleanup.
define i8 @usePtr(i8* %0, i8* %1, i8* %parentHandle) {
entry:
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
; Install this coroutine's state on the task, keeping the previous state and return pointer.
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
; Suspend after the call to @sleep; case 0 continues in %wakeup.
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %wakeup
i8 1, label %cleanup
]

wakeup:                                           ; preds = %entry
; The pointer argument is dereferenced only here, after the suspend.
%2 = load i8, i8* %0, align 1
; The result is written through the task's return pointer instead of being returned directly.
store i8 %2, i8* %task.retPtr, align 1
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
br label %cleanup

suspend:                                          ; preds = %entry, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef

; Releases the memory reported by llvm.coro.free, then joins the common %suspend exit.
cleanup:                                          ; preds = %entry, %wakeup
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend
}

define void @sleepGoroutine(i8* %0, i8* %parentHandle) {
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
Expand Down

0 comments on commit ecd8c2d

Please sign in to comment.