Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions compiler/asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"go/token"
"go/types"

"golang.org/x/tools/go/ssa"
"tinygo.org/x/go-llvm"
)

Expand Down Expand Up @@ -151,31 +152,36 @@ func (b *builder) createChanBoundsCheck(elementSize uint64, bufSize llvm.Value,
// createNilCheck checks whether the given pointer is nil, and panics if it is.
// It has no effect in well-behaved programs, but makes sure no uncaught nil
// pointer dereferences exist in valid Go code.
func (b *builder) createNilCheck(ptr llvm.Value, blockPrefix string) {
func (b *builder) createNilCheck(inst ssa.Value, ptr llvm.Value, blockPrefix string) {
// Check whether we need to emit this check at all.
if !ptr.IsAGlobalValue().IsNil() {
return
}

// Compare against nil.
var isnil llvm.Value
if ptr.Type().PointerAddressSpace() == 0 {
// Do the nil check using the isnil builtin, which marks the parameter
// as nocapture.
// The reason it has to go through a builtin, is that a regular icmp
// instruction may capture the pointer in LLVM semantics, see
// https://reviews.llvm.org/D60047 for details. Pointer capturing
// unfortunately breaks escape analysis, so we use this trick to let the
// functionattr pass know that this pointer doesn't really escape.
ptr = b.CreateBitCast(ptr, b.i8ptrType, "")
isnil = b.createRuntimeCall("isnil", []llvm.Value{ptr}, "")
} else {
// Do the nil check using a regular icmp. This can happen with function
// pointers on AVR, which don't benefit from escape analysis anyway.
nilptr := llvm.ConstPointerNull(ptr.Type())
isnil = b.CreateICmp(llvm.IntEQ, ptr, nilptr, "")
switch inst := inst.(type) {
case *ssa.IndexAddr:
// This pointer is the result of an index operation into a slice or
// array. Such slices/arrays are already bounds checked so the pointer
// must be a valid (non-nil) pointer. No nil checking is necessary.
return
case *ssa.Convert:
// This is a pointer that comes from a conversion from unsafe.Pointer.
// Don't do nil checking because this is unsafe code and the code should
// know what it is doing.
// Note: all *ssa.Convert instructions that result in a pointer must
// come from unsafe.Pointer. Testing here for unsafe.Pointer to be sure.
if inst.X.Type() == types.Typ[types.UnsafePointer] {
return
}
}

// Compare against nil.
// We previously used a hack to make sure this wouldn't break escape
// analysis, but this is not necessary anymore since
// https://reviews.llvm.org/D60047 has been merged.
nilptr := llvm.ConstPointerNull(ptr.Type())
isnil := b.CreateICmp(llvm.IntEQ, ptr, nilptr, "")

// Emit the nil check in IR.
b.createRuntimeAssert(isnil, blockPrefix, "nilPanic")
}
Expand Down
82 changes: 70 additions & 12 deletions compiler/calls.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package compiler

import (
"go/types"

"tinygo.org/x/go-llvm"
)

Expand All @@ -11,6 +13,16 @@ import (
// a struct contains more fields, it is passed as a struct without expanding.
const MaxFieldsPerParam = 3

// paramFlags identifies parameter attributes for flags. Most importantly, it
// determines which parameters are dereferenceable_or_null and which aren't.
type paramFlags uint8

const (
// Parameter may have the deferenceable_or_null attribute. This attribute
// cannot be applied to unsafe.Pointer and to the data pointer of slices.
paramIsDeferenceableOrNull = 1 << iota
)

// createCall creates a new call to runtime.<fnName> with the given arguments.
func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value {
fullName := "runtime." + fnName
Expand All @@ -36,19 +48,19 @@ func (b *builder) createCall(fn llvm.Value, args []llvm.Value, name string) llvm

// Expand an argument type to a list that can be used in a function call
// parameter list.
func expandFormalParamType(t llvm.Type) []llvm.Type {
func expandFormalParamType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
switch t.TypeKind() {
case llvm.StructTypeKind:
fields := flattenAggregateType(t)
fields, fieldFlags := flattenAggregateType(t, goType)
if len(fields) <= MaxFieldsPerParam {
return fields
return fields, fieldFlags
} else {
// failed to lower
return []llvm.Type{t}
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
}
default:
// TODO: split small arrays
return []llvm.Type{t}
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
}
}

Expand Down Expand Up @@ -79,7 +91,7 @@ func (b *builder) expandFormalParamOffsets(t llvm.Type) []uint64 {
func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
switch v.Type().TypeKind() {
case llvm.StructTypeKind:
fieldTypes := flattenAggregateType(v.Type())
fieldTypes, _ := flattenAggregateType(v.Type(), nil)
if len(fieldTypes) <= MaxFieldsPerParam {
fields := b.flattenAggregate(v)
if len(fields) != len(fieldTypes) {
Expand All @@ -98,17 +110,62 @@ func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {

// Try to flatten a struct type to a list of types. Returns a 1-element slice
// with the passed in type if this is not possible.
func flattenAggregateType(t llvm.Type) []llvm.Type {
func flattenAggregateType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
typeFlags := getTypeFlags(goType)
switch t.TypeKind() {
case llvm.StructTypeKind:
fields := make([]llvm.Type, 0, t.StructElementTypesCount())
for _, subfield := range t.StructElementTypes() {
subfields := flattenAggregateType(subfield)
fieldFlags := make([]paramFlags, 0, cap(fields))
for i, subfield := range t.StructElementTypes() {
subfields, subfieldFlags := flattenAggregateType(subfield, extractSubfield(goType, i))
for i := range subfieldFlags {
subfieldFlags[i] |= typeFlags
}
fields = append(fields, subfields...)
fieldFlags = append(fieldFlags, subfieldFlags...)
}
return fields
return fields, fieldFlags
default:
return []llvm.Type{t}, []paramFlags{typeFlags}
}
}

// getTypeFlags returns the type flags for a given type. It will not recurse
// into sub-types (such as in structs).
func getTypeFlags(t types.Type) paramFlags {
if t == nil {
return 0
}
switch t.Underlying().(type) {
case *types.Pointer:
// Pointers in Go must either point to an object or be nil.
return paramIsDeferenceableOrNull
case *types.Chan, *types.Map:
// Channels and maps are implemented as pointers pointing to some
// object, and follow the same rules as *types.Pointer.
return paramIsDeferenceableOrNull
default:
return 0
}
}

// extractSubfield extracts a field from a struct, or returns null if this is
// not a struct and thus no subfield can be obtained.
func extractSubfield(t types.Type, field int) types.Type {
if t == nil {
return nil
}
switch t := t.Underlying().(type) {
case *types.Struct:
return t.Field(field).Type()
case *types.Interface, *types.Slice, *types.Basic, *types.Signature:
// These Go types are (sometimes) implemented as LLVM structs but can't
// really be split further up in Go (with the possible exception of
// complex numbers).
return nil
default:
return []llvm.Type{t}
// This should be unreachable.
panic("cannot split subfield: " + t.String())
}
}

Expand Down Expand Up @@ -169,7 +226,8 @@ func (b *builder) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Val
func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) {
switch t.TypeKind() {
case llvm.StructTypeKind:
if len(flattenAggregateType(t)) <= MaxFieldsPerParam {
flattened, _ := flattenAggregateType(t, nil)
if len(flattened) <= MaxFieldsPerParam {
value := llvm.ConstNull(t)
for i, subtyp := range t.StructElementTypes() {
structField, remaining := b.collapseFormalParamInternal(subtyp, fields)
Expand Down
38 changes: 28 additions & 10 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,6 @@ func Compile(pkgName string, machine llvm.TargetMachine, config *compileopts.Con
c.mod.NamedFunction("runtime.alloc").AddAttributeAtIndex(0, getAttr(attrName))
}

// See createNilCheck in asserts.go.
c.mod.NamedFunction("runtime.isnil").AddAttributeAtIndex(1, nocapture)

// On *nix systems, the "abort" functuion in libc is used to handle fatal panics.
// Mark it as noreturn so LLVM can optimize away code.
if abort := c.mod.NamedFunction("abort"); !abort.IsNil() && abort.IsDeclaration() {
Expand Down Expand Up @@ -750,17 +747,20 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
}

var paramTypes []llvm.Type
var paramTypeVariants []paramFlags
for _, param := range f.Params {
paramType := c.getLLVMType(param.Type())
paramTypeFragments := expandFormalParamType(paramType)
paramTypeFragments, paramTypeFragmentVariants := expandFormalParamType(paramType, param.Type())
paramTypes = append(paramTypes, paramTypeFragments...)
paramTypeVariants = append(paramTypeVariants, paramTypeFragmentVariants...)
}

// Add an extra parameter as the function context. This context is used in
// closures and bound methods, but should be optimized away when not used.
if !f.IsExported() {
paramTypes = append(paramTypes, c.i8ptrType) // context
paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine
paramTypeVariants = append(paramTypeVariants, 0, 0)
}

fnType := llvm.FunctionType(retType, paramTypes, false)
Expand All @@ -771,6 +771,23 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
f.LLVMFn = llvm.AddFunction(c.mod, name, fnType)
}

dereferenceableOrNullKind := llvm.AttributeKindID("dereferenceable_or_null")
for i, typ := range paramTypes {
if paramTypeVariants[i]&paramIsDeferenceableOrNull == 0 {
continue
}
if typ.TypeKind() == llvm.PointerTypeKind {
el := typ.ElementType()
size := c.targetData.TypeAllocSize(el)
if size == 0 {
// dereferenceable_or_null(0) appears to be illegal in LLVM.
continue
}
dereferenceableOrNull := c.ctx.CreateEnumAttribute(dereferenceableOrNullKind, size)
f.LLVMFn.AddAttributeAtIndex(i+1, dereferenceableOrNull)
}
}

// External/exported functions may not retain pointer values.
// https://golang.org/cmd/cgo/#hdr-Passing_pointers
if f.IsExported() {
Expand Down Expand Up @@ -901,7 +918,8 @@ func (b *builder) createFunctionDefinition() {
for _, param := range b.fn.Params {
llvmType := b.getLLVMType(param.Type())
fields := make([]llvm.Value, 0, 1)
for range expandFormalParamType(llvmType) {
fieldFragments, _ := expandFormalParamType(llvmType, nil)
for range fieldFragments {
fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex))
llvmParamIndex++
}
Expand Down Expand Up @@ -1133,7 +1151,7 @@ func (b *builder) createInstruction(instr ssa.Instruction) {
case *ssa.Store:
llvmAddr := b.getValue(instr.Addr)
llvmVal := b.getValue(instr.Val)
b.createNilCheck(llvmAddr, "store")
b.createNilCheck(instr.Addr, llvmAddr, "store")
if b.targetData.TypeAllocSize(llvmVal.Type()) == 0 {
// nothing to store
return
Expand Down Expand Up @@ -1381,7 +1399,7 @@ func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error)
// This is a func value, which cannot be called directly. We have to
// extract the function pointer and context first from the func value.
callee, context = b.decodeFuncValue(value, instr.Value.Type().Underlying().(*types.Signature))
b.createNilCheck(callee, "fpcall")
b.createNilCheck(instr.Value, callee, "fpcall")
}

var params []llvm.Value
Expand Down Expand Up @@ -1527,7 +1545,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) {
// > For an operand x of type T, the address operation &x generates a
// > pointer of type *T to x. [...] If the evaluation of x would cause a
// > run-time panic, then the evaluation of &x does too.
b.createNilCheck(val, "gep")
b.createNilCheck(expr.X, val, "gep")
// Do a GEP on the pointer to get the field address.
indices := []llvm.Value{
llvm.ConstInt(b.ctx.Int32Type(), 0, false),
Expand Down Expand Up @@ -1575,7 +1593,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) {
// > generates a pointer of type *T to x. [...] If the
// > evaluation of x would cause a run-time panic, then the
// > evaluation of &x does too.
b.createNilCheck(bufptr, "gep")
b.createNilCheck(expr.X, bufptr, "gep")
default:
return llvm.Value{}, b.makeError(expr.Pos(), "todo: indexaddr: "+typ.String())
}
Expand Down Expand Up @@ -2567,7 +2585,7 @@ func (b *builder) createUnOp(unop *ssa.UnOp) (llvm.Value, error) {
}
return b.CreateBitCast(fn, b.i8ptrType, ""), nil
} else {
b.createNilCheck(x, "deref")
b.createNilCheck(unop.X, x, "deref")
load := b.CreateLoad(x, "")
return load, nil
}
Expand Down
6 changes: 4 additions & 2 deletions compiler/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,13 @@ func (c *compilerContext) getRawFuncType(typ *types.Signature) llvm.Type {
// The receiver is not an interface, but a i8* type.
recv = c.i8ptrType
}
paramTypes = append(paramTypes, expandFormalParamType(recv)...)
recvFragments, _ := expandFormalParamType(recv, nil)
paramTypes = append(paramTypes, recvFragments...)
}
for i := 0; i < typ.Params().Len(); i++ {
subType := c.getLLVMType(typ.Params().At(i).Type())
paramTypes = append(paramTypes, expandFormalParamType(subType)...)
paramTypeFragments, _ := expandFormalParamType(subType, nil)
paramTypes = append(paramTypes, paramTypeFragments...)
}
// All functions take these parameters at the end.
paramTypes = append(paramTypes, c.i8ptrType) // context
Expand Down
2 changes: 1 addition & 1 deletion compiler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ func (c *compilerContext) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value {

// Get the expanded receiver type.
receiverType := c.getLLVMType(f.Params[0].Type())
expandedReceiverType := expandFormalParamType(receiverType)
expandedReceiverType, _ := expandFormalParamType(receiverType, nil)

// Does this method even need any wrapping?
if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind {
Expand Down
4 changes: 2 additions & 2 deletions compiler/volatile.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
// runtime/volatile.LoadT().
func (b *builder) createVolatileLoad(instr *ssa.CallCommon) (llvm.Value, error) {
addr := b.getValue(instr.Args[0])
b.createNilCheck(addr, "deref")
b.createNilCheck(instr.Args[0], addr, "deref")
val := b.CreateLoad(addr, "")
val.SetVolatile(true)
return val, nil
Expand All @@ -23,7 +23,7 @@ func (b *builder) createVolatileLoad(instr *ssa.CallCommon) (llvm.Value, error)
func (b *builder) createVolatileStore(instr *ssa.CallCommon) (llvm.Value, error) {
addr := b.getValue(instr.Args[0])
val := b.getValue(instr.Args[1])
b.createNilCheck(addr, "deref")
b.createNilCheck(instr.Args[0], addr, "deref")
store := b.CreateStore(val, addr)
store.SetVolatile(true)
return llvm.Value{}, nil
Expand Down
8 changes: 0 additions & 8 deletions src/runtime/panic.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ func _recover() interface{} {
return nil
}

// See emitNilCheck in compiler/asserts.go.
// This function is a dummy function that has its first and only parameter
// marked 'nocapture' to work around a limitation in LLVM: a regular pointer
// comparison captures the pointer.
func isnil(ptr *uint8) bool {
return ptr == nil
}

// Panic when trying to dereference a nil pointer.
func nilPanic() {
runtimePanic("nil pointer dereference")
Expand Down
15 changes: 2 additions & 13 deletions transform/func-lowering.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,8 @@ func LowerFuncValues(mod llvm.Module) {
panic("expected inttoptr")
}
for _, ptrUse := range getUses(callIntPtr) {
if !ptrUse.IsABitCastInst().IsNil() {
for _, bitcastUse := range getUses(ptrUse) {
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() {
panic("expected a call instruction")
}
switch bitcastUse.CalledValue().Name() {
case "runtime.isnil":
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false))
bitcastUse.EraseFromParentAsInstruction()
default:
panic("expected a call to runtime.isnil")
}
}
if !ptrUse.IsAICmpInst().IsNil() {
ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false))
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
return builder.CreateCall(funcPtr, params, "")
Expand Down
Loading