diff --git a/compiler/asserts.go b/compiler/asserts.go index 6b25478177..f1ae259eaa 100644 --- a/compiler/asserts.go +++ b/compiler/asserts.go @@ -8,6 +8,7 @@ import ( "go/token" "go/types" + "golang.org/x/tools/go/ssa" "tinygo.org/x/go-llvm" ) @@ -151,31 +152,36 @@ func (b *builder) createChanBoundsCheck(elementSize uint64, bufSize llvm.Value, // createNilCheck checks whether the given pointer is nil, and panics if it is. // It has no effect in well-behaved programs, but makes sure no uncaught nil // pointer dereferences exist in valid Go code. -func (b *builder) createNilCheck(ptr llvm.Value, blockPrefix string) { +func (b *builder) createNilCheck(inst ssa.Value, ptr llvm.Value, blockPrefix string) { // Check whether we need to emit this check at all. if !ptr.IsAGlobalValue().IsNil() { return } - // Compare against nil. - var isnil llvm.Value - if ptr.Type().PointerAddressSpace() == 0 { - // Do the nil check using the isnil builtin, which marks the parameter - // as nocapture. - // The reason it has to go through a builtin, is that a regular icmp - // instruction may capture the pointer in LLVM semantics, see - // https://reviews.llvm.org/D60047 for details. Pointer capturing - // unfortunately breaks escape analysis, so we use this trick to let the - // functionattr pass know that this pointer doesn't really escape. - ptr = b.CreateBitCast(ptr, b.i8ptrType, "") - isnil = b.createRuntimeCall("isnil", []llvm.Value{ptr}, "") - } else { - // Do the nil check using a regular icmp. This can happen with function - // pointers on AVR, which don't benefit from escape analysis anyway. - nilptr := llvm.ConstPointerNull(ptr.Type()) - isnil = b.CreateICmp(llvm.IntEQ, ptr, nilptr, "") + switch inst := inst.(type) { + case *ssa.IndexAddr: + // This pointer is the result of an index operation into a slice or + // array. Such slices/arrays are already bounds checked so the pointer + // must be a valid (non-nil) pointer. No nil checking is necessary. + return + case *ssa.Convert: + // This is a pointer that comes from a conversion from unsafe.Pointer. + // Don't do nil checking because this is unsafe code and the code should + // know what it is doing. + // Note: all *ssa.Convert instructions that result in a pointer must + // come from unsafe.Pointer. Testing here for unsafe.Pointer to be sure. + if inst.X.Type() == types.Typ[types.UnsafePointer] { + return + } } + // Compare against nil. + // We previously used a hack to make sure this wouldn't break escape + // analysis, but this is not necessary anymore since + // https://reviews.llvm.org/D60047 has been merged. + nilptr := llvm.ConstPointerNull(ptr.Type()) + isnil := b.CreateICmp(llvm.IntEQ, ptr, nilptr, "") + // Emit the nil check in IR. b.createRuntimeAssert(isnil, blockPrefix, "nilPanic") } diff --git a/compiler/calls.go b/compiler/calls.go index 1de28b24ad..e5e1f44c06 100644 --- a/compiler/calls.go +++ b/compiler/calls.go @@ -1,6 +1,8 @@ package compiler import ( + "go/types" + "tinygo.org/x/go-llvm" ) @@ -11,6 +13,16 @@ import ( // a struct contains more fields, it is passed as a struct without expanding. const MaxFieldsPerParam = 3 +// paramFlags identifies parameter attributes for flags. Most importantly, it +// determines which parameters are dereferenceable_or_null and which aren't. +type paramFlags uint8 + +const ( + // Parameter may have the deferenceable_or_null attribute. This attribute + // cannot be applied to unsafe.Pointer and to the data pointer of slices. + paramIsDeferenceableOrNull = 1 << iota +) + // createCall creates a new call to runtime. with the given arguments. func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value { fullName := "runtime." + fnName @@ -36,19 +48,19 @@ func (b *builder) createCall(fn llvm.Value, args []llvm.Value, name string) llvm // Expand an argument type to a list that can be used in a function call // parameter list. -func expandFormalParamType(t llvm.Type) []llvm.Type { +func expandFormalParamType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) { switch t.TypeKind() { case llvm.StructTypeKind: - fields := flattenAggregateType(t) + fields, fieldFlags := flattenAggregateType(t, goType) if len(fields) <= MaxFieldsPerParam { - return fields + return fields, fieldFlags } else { // failed to lower - return []llvm.Type{t} + return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)} } default: // TODO: split small arrays - return []llvm.Type{t} + return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)} } } @@ -79,7 +91,7 @@ func (b *builder) expandFormalParamOffsets(t llvm.Type) []uint64 { func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value { switch v.Type().TypeKind() { case llvm.StructTypeKind: - fieldTypes := flattenAggregateType(v.Type()) + fieldTypes, _ := flattenAggregateType(v.Type(), nil) if len(fieldTypes) <= MaxFieldsPerParam { fields := b.flattenAggregate(v) if len(fields) != len(fieldTypes) { @@ -98,17 +110,62 @@ func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value { // Try to flatten a struct type to a list of types. Returns a 1-element slice // with the passed in type if this is not possible. -func flattenAggregateType(t llvm.Type) []llvm.Type { +func flattenAggregateType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) { + typeFlags := getTypeFlags(goType) switch t.TypeKind() { case llvm.StructTypeKind: fields := make([]llvm.Type, 0, t.StructElementTypesCount()) - for _, subfield := range t.StructElementTypes() { - subfields := flattenAggregateType(subfield) + fieldFlags := make([]paramFlags, 0, cap(fields)) + for i, subfield := range t.StructElementTypes() { + subfields, subfieldFlags := flattenAggregateType(subfield, extractSubfield(goType, i)) + for i := range subfieldFlags { + subfieldFlags[i] |= typeFlags + } fields = append(fields, subfields...) + fieldFlags = append(fieldFlags, subfieldFlags...) } - return fields + return fields, fieldFlags + default: + return []llvm.Type{t}, []paramFlags{typeFlags} + } +} + +// getTypeFlags returns the type flags for a given type. It will not recurse +// into sub-types (such as in structs). +func getTypeFlags(t types.Type) paramFlags { + if t == nil { + return 0 + } + switch t.Underlying().(type) { + case *types.Pointer: + // Pointers in Go must either point to an object or be nil. + return paramIsDeferenceableOrNull + case *types.Chan, *types.Map: + // Channels and maps are implemented as pointers pointing to some + // object, and follow the same rules as *types.Pointer. + return paramIsDeferenceableOrNull + default: + return 0 + } +} + +// extractSubfield extracts a field from a struct, or returns null if this is +// not a struct and thus no subfield can be obtained. +func extractSubfield(t types.Type, field int) types.Type { + if t == nil { + return nil + } + switch t := t.Underlying().(type) { + case *types.Struct: + return t.Field(field).Type() + case *types.Interface, *types.Slice, *types.Basic, *types.Signature: + // These Go types are (sometimes) implemented as LLVM structs but can't + // really be split further up in Go (with the possible exception of + // complex numbers). + return nil default: - return []llvm.Type{t} + // This should be unreachable. + panic("cannot split subfield: " + t.String()) } } @@ -169,7 +226,8 @@ func (b *builder) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Val func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) { switch t.TypeKind() { case llvm.StructTypeKind: - if len(flattenAggregateType(t)) <= MaxFieldsPerParam { + flattened, _ := flattenAggregateType(t, nil) + if len(flattened) <= MaxFieldsPerParam { value := llvm.ConstNull(t) for i, subtyp := range t.StructElementTypes() { structField, remaining := b.collapseFormalParamInternal(subtyp, fields) diff --git a/compiler/compiler.go b/compiler/compiler.go index 966b0f62c8..720eb8df5b 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -339,9 +339,6 @@ func Compile(pkgName string, machine llvm.TargetMachine, config *compileopts.Con c.mod.NamedFunction("runtime.alloc").AddAttributeAtIndex(0, getAttr(attrName)) } - // See createNilCheck in asserts.go. - c.mod.NamedFunction("runtime.isnil").AddAttributeAtIndex(1, nocapture) - // On *nix systems, the "abort" functuion in libc is used to handle fatal panics. // Mark it as noreturn so LLVM can optimize away code. if abort := c.mod.NamedFunction("abort"); !abort.IsNil() && abort.IsDeclaration() { @@ -750,10 +747,12 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) { } var paramTypes []llvm.Type + var paramTypeVariants []paramFlags for _, param := range f.Params { paramType := c.getLLVMType(param.Type()) - paramTypeFragments := expandFormalParamType(paramType) + paramTypeFragments, paramTypeFragmentVariants := expandFormalParamType(paramType, param.Type()) paramTypes = append(paramTypes, paramTypeFragments...) + paramTypeVariants = append(paramTypeVariants, paramTypeFragmentVariants...) } // Add an extra parameter as the function context. This context is used in @@ -761,6 +760,7 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) { if !f.IsExported() { paramTypes = append(paramTypes, c.i8ptrType) // context paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine + paramTypeVariants = append(paramTypeVariants, 0, 0) } fnType := llvm.FunctionType(retType, paramTypes, false) @@ -771,6 +771,23 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) { f.LLVMFn = llvm.AddFunction(c.mod, name, fnType) } + dereferenceableOrNullKind := llvm.AttributeKindID("dereferenceable_or_null") + for i, typ := range paramTypes { + if paramTypeVariants[i]¶mIsDeferenceableOrNull == 0 { + continue + } + if typ.TypeKind() == llvm.PointerTypeKind { + el := typ.ElementType() + size := c.targetData.TypeAllocSize(el) + if size == 0 { + // dereferenceable_or_null(0) appears to be illegal in LLVM. + continue + } + dereferenceableOrNull := c.ctx.CreateEnumAttribute(dereferenceableOrNullKind, size) + f.LLVMFn.AddAttributeAtIndex(i+1, dereferenceableOrNull) + } + } + // External/exported functions may not retain pointer values. // https://golang.org/cmd/cgo/#hdr-Passing_pointers if f.IsExported() { @@ -901,7 +918,8 @@ func (b *builder) createFunctionDefinition() { for _, param := range b.fn.Params { llvmType := b.getLLVMType(param.Type()) fields := make([]llvm.Value, 0, 1) - for range expandFormalParamType(llvmType) { + fieldFragments, _ := expandFormalParamType(llvmType, nil) + for range fieldFragments { fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex)) llvmParamIndex++ } @@ -1133,7 +1151,7 @@ func (b *builder) createInstruction(instr ssa.Instruction) { case *ssa.Store: llvmAddr := b.getValue(instr.Addr) llvmVal := b.getValue(instr.Val) - b.createNilCheck(llvmAddr, "store") + b.createNilCheck(instr.Addr, llvmAddr, "store") if b.targetData.TypeAllocSize(llvmVal.Type()) == 0 { // nothing to store return @@ -1381,7 +1399,7 @@ func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error) // This is a func value, which cannot be called directly. We have to // extract the function pointer and context first from the func value. callee, context = b.decodeFuncValue(value, instr.Value.Type().Underlying().(*types.Signature)) - b.createNilCheck(callee, "fpcall") + b.createNilCheck(instr.Value, callee, "fpcall") } var params []llvm.Value @@ -1527,7 +1545,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { // > For an operand x of type T, the address operation &x generates a // > pointer of type *T to x. [...] If the evaluation of x would cause a // > run-time panic, then the evaluation of &x does too. - b.createNilCheck(val, "gep") + b.createNilCheck(expr.X, val, "gep") // Do a GEP on the pointer to get the field address. indices := []llvm.Value{ llvm.ConstInt(b.ctx.Int32Type(), 0, false), @@ -1575,7 +1593,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { // > generates a pointer of type *T to x. [...] If the // > evaluation of x would cause a run-time panic, then the // > evaluation of &x does too. - b.createNilCheck(bufptr, "gep") + b.createNilCheck(expr.X, bufptr, "gep") default: return llvm.Value{}, b.makeError(expr.Pos(), "todo: indexaddr: "+typ.String()) } @@ -2567,7 +2585,7 @@ func (b *builder) createUnOp(unop *ssa.UnOp) (llvm.Value, error) { } return b.CreateBitCast(fn, b.i8ptrType, ""), nil } else { - b.createNilCheck(x, "deref") + b.createNilCheck(unop.X, x, "deref") load := b.CreateLoad(x, "") return load, nil } diff --git a/compiler/func.go b/compiler/func.go index 544f3e5fa3..2d14d47a20 100644 --- a/compiler/func.go +++ b/compiler/func.go @@ -125,11 +125,13 @@ func (c *compilerContext) getRawFuncType(typ *types.Signature) llvm.Type { // The receiver is not an interface, but a i8* type. recv = c.i8ptrType } - paramTypes = append(paramTypes, expandFormalParamType(recv)...) + recvFragments, _ := expandFormalParamType(recv, nil) + paramTypes = append(paramTypes, recvFragments...) } for i := 0; i < typ.Params().Len(); i++ { subType := c.getLLVMType(typ.Params().At(i).Type()) - paramTypes = append(paramTypes, expandFormalParamType(subType)...) + paramTypeFragments, _ := expandFormalParamType(subType, nil) + paramTypes = append(paramTypes, paramTypeFragments...) } // All functions take these parameters at the end. paramTypes = append(paramTypes, c.i8ptrType) // context diff --git a/compiler/interface.go b/compiler/interface.go index c6a8a61ac9..4613e2e3ca 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -437,7 +437,7 @@ func (c *compilerContext) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value { // Get the expanded receiver type. receiverType := c.getLLVMType(f.Params[0].Type()) - expandedReceiverType := expandFormalParamType(receiverType) + expandedReceiverType, _ := expandFormalParamType(receiverType, nil) // Does this method even need any wrapping? if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind { diff --git a/compiler/volatile.go b/compiler/volatile.go index d2e51c84da..0c94587451 100644 --- a/compiler/volatile.go +++ b/compiler/volatile.go @@ -12,7 +12,7 @@ import ( // runtime/volatile.LoadT(). func (b *builder) createVolatileLoad(instr *ssa.CallCommon) (llvm.Value, error) { addr := b.getValue(instr.Args[0]) - b.createNilCheck(addr, "deref") + b.createNilCheck(instr.Args[0], addr, "deref") val := b.CreateLoad(addr, "") val.SetVolatile(true) return val, nil @@ -23,7 +23,7 @@ func (b *builder) createVolatileLoad(instr *ssa.CallCommon) (llvm.Value, error) func (b *builder) createVolatileStore(instr *ssa.CallCommon) (llvm.Value, error) { addr := b.getValue(instr.Args[0]) val := b.getValue(instr.Args[1]) - b.createNilCheck(addr, "deref") + b.createNilCheck(instr.Args[0], addr, "deref") store := b.CreateStore(val, addr) store.SetVolatile(true) return llvm.Value{}, nil diff --git a/src/runtime/panic.go b/src/runtime/panic.go index 8f8d73662c..2efeed2060 100644 --- a/src/runtime/panic.go +++ b/src/runtime/panic.go @@ -27,14 +27,6 @@ func _recover() interface{} { return nil } -// See emitNilCheck in compiler/asserts.go. -// This function is a dummy function that has its first and only parameter -// marked 'nocapture' to work around a limitation in LLVM: a regular pointer -// comparison captures the pointer. -func isnil(ptr *uint8) bool { - return ptr == nil -} - // Panic when trying to dereference a nil pointer. func nilPanic() { runtimePanic("nil pointer dereference") diff --git a/transform/func-lowering.go b/transform/func-lowering.go index 9687cc7697..ed385b5f73 100644 --- a/transform/func-lowering.go +++ b/transform/func-lowering.go @@ -197,19 +197,8 @@ func LowerFuncValues(mod llvm.Module) { panic("expected inttoptr") } for _, ptrUse := range getUses(callIntPtr) { - if !ptrUse.IsABitCastInst().IsNil() { - for _, bitcastUse := range getUses(ptrUse) { - if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() { - panic("expected a call instruction") - } - switch bitcastUse.CalledValue().Name() { - case "runtime.isnil": - bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false)) - bitcastUse.EraseFromParentAsInstruction() - default: - panic("expected a call to runtime.isnil") - } - } + if !ptrUse.IsAICmpInst().IsNil() { + ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false)) } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr { addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { return builder.CreateCall(funcPtr, params, "") diff --git a/transform/optimizer.go b/transform/optimizer.go index 2677e69026..9c77c312da 100644 --- a/transform/optimizer.go +++ b/transform/optimizer.go @@ -98,24 +98,6 @@ func Optimize(mod llvm.Module, config *compileopts.Config, optLevel, sizeLevel i OptimizeAllocs(mod) OptimizeStringToBytes(mod) - // Lower runtime.isnil calls to regular nil comparisons. - isnil := mod.NamedFunction("runtime.isnil") - if !isnil.IsNil() { - builder := mod.Context().NewBuilder() - defer builder.Dispose() - for _, use := range getUses(isnil) { - builder.SetInsertPointBefore(use) - ptr := use.Operand(0) - if !ptr.IsABitCastInst().IsNil() { - ptr = ptr.Operand(0) - } - nilptr := llvm.ConstPointerNull(ptr.Type()) - icmp := builder.CreateICmp(llvm.IntEQ, ptr, nilptr, "") - use.ReplaceAllUsesWith(icmp) - use.EraseFromParentAsInstruction() - } - } - } else { // Must be run at any optimization level. err := LowerInterfaces(mod) diff --git a/transform/testdata/func-lowering.ll b/transform/testdata/func-lowering.ll index b9aea6c031..b5692c7089 100644 --- a/transform/testdata/func-lowering.ll +++ b/transform/testdata/func-lowering.ll @@ -19,8 +19,6 @@ declare void @"internal/task.start"(i32, i8*, i8*, i8*) declare void @runtime.nilPanic(i8*, i8*) -declare i1 @runtime.isnil(i8*, i8*, i8*) - declare void @"main$1"(i32, i8*, i8*) declare void @"main$2"(i32, i8*, i8*) @@ -38,9 +36,8 @@ define void @runFunc1(i8*, i32, i8, i8* %context, i8* %parentHandle) { entry: %3 = call i32 @runtime.getFuncPtr(i8* %0, i32 %1, %runtime.typecodeID* @"reflect/types.type:func:{basic:int8}{}", i8* undef, i8* null) %4 = inttoptr i32 %3 to void (i8, i8*, i8*)* - %5 = bitcast void (i8, i8*, i8*)* %4 to i8* - %6 = call i1 @runtime.isnil(i8* %5, i8* undef, i8* null) - br i1 %6, label %fpcall.nil, label %fpcall.next + %5 = icmp eq void (i8, i8*, i8*)* %4, null + br i1 %5, label %fpcall.nil, label %fpcall.next fpcall.nil: call void @runtime.nilPanic(i8* undef, i8* null) @@ -58,9 +55,8 @@ define void @runFunc2(i8*, i32, i8, i8* %context, i8* %parentHandle) { entry: %3 = call i32 @runtime.getFuncPtr(i8* %0, i32 %1, %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}", i8* undef, i8* null) %4 = inttoptr i32 %3 to void (i8, i8*, i8*)* - %5 = bitcast void (i8, i8*, i8*)* %4 to i8* - %6 = call i1 @runtime.isnil(i8* %5, i8* undef, i8* null) - br i1 %6, label %fpcall.nil, label %fpcall.next + %5 = icmp eq void (i8, i8*, i8*)* %4, null + br i1 %5, label %fpcall.nil, label %fpcall.next fpcall.nil: call void @runtime.nilPanic(i8* undef, i8* null) diff --git a/transform/testdata/func-lowering.out.ll b/transform/testdata/func-lowering.out.ll index af63fdc12c..97621730bd 100644 --- a/transform/testdata/func-lowering.out.ll +++ b/transform/testdata/func-lowering.out.ll @@ -19,8 +19,6 @@ declare void @"internal/task.start"(i32, i8*, i8*, i8*) declare void @runtime.nilPanic(i8*, i8*) -declare i1 @runtime.isnil(i8*, i8*, i8*) - declare void @"main$1"(i32, i8*, i8*) declare void @"main$2"(i32, i8*, i8*) @@ -38,9 +36,8 @@ define void @runFunc1(i8*, i32, i8, i8* %context, i8* %parentHandle) { entry: %3 = icmp eq i32 %1, 0 %4 = select i1 %3, void (i8, i8*, i8*)* null, void (i8, i8*, i8*)* @funcInt8 - %5 = bitcast void (i8, i8*, i8*)* %4 to i8* - %6 = call i1 @runtime.isnil(i8* %5, i8* undef, i8* null) - br i1 %6, label %fpcall.nil, label %fpcall.next + %5 = icmp eq void (i8, i8*, i8*)* %4, null + br i1 %5, label %fpcall.nil, label %fpcall.next fpcall.nil: call void @runtime.nilPanic(i8* undef, i8* null)