forked from tinygo-org/tinygo
/
func.go
265 lines (242 loc) · 9.39 KB
/
func.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
package compiler
// This file implements function values and closures. It may need some lowering
// in a later step, see func-lowering.go.
import (
"go/types"
"golang.org/x/tools/go/ssa"
"tinygo.org/x/go-llvm"
)
// funcValueImplementation enumerates the ways a Go func value can be
// represented in the generated LLVM IR. The variant in use is chosen by
// funcImplementation.
type funcValueImplementation int
const (
// funcValueNone is the zero value. funcImplementation never returns it, and
// the switch statements in createFuncValue, decodeFuncValue and getFuncType
// panic with "unimplemented func value variant" if they encounter it.
funcValueNone funcValueImplementation = iota
// A func value is implemented as a pair of pointers:
// {context, function pointer}
// where the context may be a pointer to a heap-allocated struct containing
// the free variables, or it may be undef if the function being pointed to
// doesn't need a context. The function pointer is a regular function
// pointer.
// Note that parseMakeClosure also stores the free variables directly inside
// the context pointer when they are no bigger than a pointer.
funcValueDoubleword
// As funcValueDoubleword, but with the function pointer replaced by a
// unique ID per function signature. Function values are called by using a
// switch statement and choosing which function to call.
funcValueSwitch
)
// funcImplementation reports which func value representation is used for the
// target being compiled: funcValueSwitch when GOARCH is "wasm", and
// funcValueDoubleword for every other architecture.
func (c *Compiler) funcImplementation() funcValueImplementation {
	switch c.GOARCH {
	case "wasm":
		return funcValueSwitch
	default:
		return funcValueDoubleword
	}
}
// createFuncValue builds a func value for signature sig out of a raw function
// pointer and a context value. The result is an aggregate with the context in
// field 0 and an implementation-dependent scalar in field 1:
//
//   - funcValueDoubleword: the scalar is funcPtr itself.
//   - funcValueSwitch: the scalar is the address (converted to uintptr) of an
//     internal constant global named "<name of funcPtr>$withSignature". That
//     global has type runtime.funcValueWithSignature and holds the function
//     pointer as an integer together with the signature global returned by
//     getFuncSignature. The global is created on first use and reused after
//     that.
//
// An error is returned only when the func value type cannot be determined for
// sig. Any other implementation value causes a panic.
func (c *Compiler) createFuncValue(funcPtr, context llvm.Value, sig *types.Signature) (llvm.Value, error) {
	var scalar llvm.Value
	switch impl := c.funcImplementation(); impl {
	case funcValueDoubleword:
		// Closure is: {context, function pointer}
		scalar = funcPtr
	case funcValueSwitch:
		signature := c.getFuncSignature(sig)
		globalName := funcPtr.Name() + "$withSignature"
		global := c.mod.NamedGlobal(globalName)
		if global.IsNil() {
			// First reference to this function: emit the descriptor global.
			globalType := c.mod.GetTypeByName("runtime.funcValueWithSignature")
			fields := []llvm.Value{
				llvm.ConstPtrToInt(funcPtr, c.uintptrType),
				signature,
			}
			initializer := llvm.ConstNamedStruct(globalType, fields)
			global = llvm.AddGlobal(c.mod, globalType, globalName)
			global.SetInitializer(initializer)
			global.SetGlobalConstant(true)
			global.SetLinkage(llvm.InternalLinkage)
		}
		scalar = llvm.ConstPtrToInt(global, c.uintptrType)
	default:
		panic("unimplemented func value variant")
	}

	valueType, err := c.getFuncType(sig)
	if err != nil {
		return llvm.Value{}, err
	}

	// Fill in both fields of an otherwise undefined aggregate.
	withContext := c.builder.CreateInsertValue(llvm.Undef(valueType), context, 0, "")
	return c.builder.CreateInsertValue(withContext, scalar, 1, ""), nil
}
// getFuncSignature returns the global that identifies the given function
// signature. The global is named "reflect/types.type:" followed by the type
// code name of sig; it is looked up in the module first and only created (as
// an internal constant i8 with an undef initializer) when it does not exist
// yet. The global is stored in runtime.funcValueWithSignature values and is
// passed to the getFuncPtr runtime call.
func (c *Compiler) getFuncSignature(sig *types.Signature) llvm.Value {
	name := "reflect/types.type:" + getTypeCodeName(sig)
	if existing := c.mod.NamedGlobal(name); !existing.IsNil() {
		return existing
	}

	i8 := c.ctx.Int8Type()
	global := llvm.AddGlobal(c.mod, i8, name)
	global.SetInitializer(llvm.Undef(i8))
	global.SetGlobalConstant(true)
	global.SetLinkage(llvm.InternalLinkage)
	return global
}
// extractFuncScalar returns field 1 of the func value: a scalar that can be
// used in comparisons. It is a cheap operation (a single extractvalue).
func (c *Compiler) extractFuncScalar(funcValue llvm.Value) llvm.Value {
	const scalarField = 1
	scalar := c.builder.CreateExtractValue(funcValue, scalarField, "")
	return scalar
}
// extractFuncContext returns field 0 of the func value, which holds the
// context pointer. It is a cheap operation (a single extractvalue).
func (c *Compiler) extractFuncContext(funcValue llvm.Value) llvm.Value {
	const contextField = 0
	context := c.builder.CreateExtractValue(funcValue, contextField, "")
	return context
}
// decodeFuncValue splits a func value into a callable function pointer and
// its context. The context is always field 0 of the value. How the function
// pointer is obtained depends on the implementation:
//
//   - funcValueDoubleword: it is field 1 of the value.
//   - funcValueSwitch: the runtime function getFuncPtr is called with the func
//     value and the signature global, and its result is converted to a
//     pointer of the raw function type for sig. This may be an expensive
//     operation.
//
// A non-nil error is returned (with zero values for both results) when the
// raw function type for sig cannot be determined. Any other implementation
// value causes a panic.
func (c *Compiler) decodeFuncValue(funcValue llvm.Value, sig *types.Signature) (funcPtr, context llvm.Value, err error) {
	ctx := c.builder.CreateExtractValue(funcValue, 0, "")

	switch c.funcImplementation() {
	case funcValueDoubleword:
		fn := c.builder.CreateExtractValue(funcValue, 1, "")
		return fn, ctx, nil
	case funcValueSwitch:
		rawType, typeErr := c.getRawFuncType(sig)
		if typeErr != nil {
			return llvm.Value{}, llvm.Value{}, typeErr
		}
		signature := c.getFuncSignature(sig)
		id := c.createRuntimeCall("getFuncPtr", []llvm.Value{funcValue, signature}, "")
		return c.builder.CreateIntToPtr(id, rawType, ""), ctx, nil
	default:
		panic("unimplemented func value variant")
	}
}
// getFuncType returns the LLVM type used for func values with the given
// signature. For funcValueDoubleword this is an anonymous struct of an i8*
// context and a raw function pointer for the signature; for funcValueSwitch
// it is the named type runtime.funcValue regardless of the signature. An
// error is returned when the raw function pointer type cannot be determined,
// and any other implementation value causes a panic.
func (c *Compiler) getFuncType(typ *types.Signature) (llvm.Type, error) {
	impl := c.funcImplementation()
	if impl == funcValueSwitch {
		return c.mod.GetTypeByName("runtime.funcValue"), nil
	}
	if impl != funcValueDoubleword {
		panic("unimplemented func value variant")
	}

	fnPtrType, err := c.getRawFuncType(typ)
	if err != nil {
		return llvm.Type{}, err
	}
	fields := []llvm.Type{c.i8ptrType, fnPtrType}
	return c.ctx.StructType(fields, false), nil
}
// getRawFuncType returns the LLVM function pointer type (in the function
// pointer address space) for the given Go signature.
//
// The return type is void for no results, the LLVM type of the result for a
// single result, and an anonymous struct of all result types otherwise. The
// parameter list consists of the (expanded) receiver if there is one, the
// (expanded) regular parameters, and finally two i8* parameters that every
// function takes: the context and the parent coroutine.
//
// The first error encountered while converting a Go type to an LLVM type is
// returned.
func (c *Compiler) getRawFuncType(typ *types.Signature) (llvm.Type, error) {
	// Work out the return type.
	results := typ.Results()
	var returnType llvm.Type
	switch count := results.Len(); count {
	case 0:
		// No return values.
		returnType = c.ctx.VoidType()
	case 1:
		// Just one return value.
		single, err := c.getLLVMType(results.At(0).Type())
		if err != nil {
			return llvm.Type{}, err
		}
		returnType = single
	default:
		// Multiple return values are bundled in a struct, which appears to be
		// the common way to handle them in LLVM.
		members := make([]llvm.Type, count)
		for i := range members {
			member, err := c.getLLVMType(results.At(i).Type())
			if err != nil {
				return llvm.Type{}, err
			}
			members[i] = member
		}
		returnType = c.ctx.StructType(members, false)
	}

	// Work out the parameter types, starting with the receiver (if any).
	var paramTypes []llvm.Type
	if receiver := typ.Recv(); receiver != nil {
		recvType, err := c.getLLVMType(receiver.Type())
		if err != nil {
			return llvm.Type{}, err
		}
		if recvType.StructName() == "runtime._interface" {
			// This is a call on an interface, not a concrete type.
			// The receiver is not an interface, but a i8* type.
			recvType = c.i8ptrType
		}
		paramTypes = append(paramTypes, c.expandFormalParamType(recvType)...)
	}
	params := typ.Params()
	for i, count := 0, params.Len(); i < count; i++ {
		paramType, err := c.getLLVMType(params.At(i).Type())
		if err != nil {
			return llvm.Type{}, err
		}
		paramTypes = append(paramTypes, c.expandFormalParamType(paramType)...)
	}

	// All functions take these parameters at the end: the context and the
	// parent coroutine, in that order.
	paramTypes = append(paramTypes, c.i8ptrType, c.i8ptrType)

	// Make a func pointer type out of the signature.
	fnType := llvm.FunctionType(returnType, paramTypes, false)
	return llvm.PointerType(fnType, c.funcPtrAddrSpace), nil
}
// parseMakeClosure makes a function value (with context) from the given
// closure expression.
//
// All bound variables of the closure are evaluated and collected in an
// anonymous struct, the context. How the context is passed along depends on
// its size:
//
//   - If it is no bigger than a pointer, it is packed directly into the i8*
//     context value. This is done through a pointer-sized stack slot: the
//     slot is zeroed when the struct is smaller than a pointer, the bound
//     variables are stored into it, and the slot is then loaded as an i8*.
//   - Otherwise it is allocated with the runtime alloc function and the
//     returned pointer is used as the context.
//
// The resulting context and the LLVM function of the closure are combined
// into a func value by createFuncValue. An error is returned when a binding
// cannot be compiled or when the func value cannot be created. The function
// panics when the closure has no bound variables.
func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.Value, error) {
	if len(expr.Bindings) == 0 {
		panic("unexpected: MakeClosure without bound variables")
	}
	f := c.ir.GetFunction(expr.Fn.(*ssa.Function))

	// Collect all bound variables.
	boundVars := make([]llvm.Value, 0, len(expr.Bindings))
	boundVarTypes := make([]llvm.Type, 0, len(expr.Bindings))
	for _, binding := range expr.Bindings {
		// The context stores the bound variables.
		llvmBoundVar, err := c.parseExpr(frame, binding)
		if err != nil {
			return llvm.Value{}, err
		}
		boundVars = append(boundVars, llvmBoundVar)
		boundVarTypes = append(boundVarTypes, llvmBoundVar.Type())
	}
	contextType := c.ctx.StructType(boundVarTypes, false)

	// Decide once whether the context can be packed into a pointer.
	contextSize := c.targetData.TypeAllocSize(contextType)
	pointerSize := c.targetData.TypeAllocSize(c.i8ptrType)
	fitsInPointer := contextSize <= pointerSize

	// Allocate memory for the context.
	var packedSlot, contextAlloc, contextHeapAlloc llvm.Value
	if fitsInPointer {
		// Context fits in a pointer - e.g. when it is a pointer. Because
		// contextType is a struct and it has to end up as an i8*, go through
		// memory (store+bitcast+load). The stack slot has the size of a
		// pointer, not of the struct: loading an i8* from a slot that is
		// smaller than a pointer would read past the end of the allocation.
		packedSlot = c.builder.CreateAlloca(c.i8ptrType, "")
		if contextSize < pointerSize {
			// The struct does not cover the whole slot. Zero it first so
			// that the loaded pointer has no undefined bits.
			c.builder.CreateStore(llvm.ConstNull(c.i8ptrType), packedSlot)
		}
		contextAlloc = c.builder.CreateBitCast(packedSlot, llvm.PointerType(contextType, 0), "")
	} else {
		// Context is bigger than a pointer, so allocate it on the heap.
		sizeValue := llvm.ConstInt(c.uintptrType, contextSize, false)
		contextHeapAlloc = c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "")
		contextAlloc = c.builder.CreateBitCast(contextHeapAlloc, llvm.PointerType(contextType, 0), "")
	}

	// Store all bound variables in the stack slot or heap allocation.
	for i, boundVar := range boundVars {
		indices := []llvm.Value{
			llvm.ConstInt(c.ctx.Int32Type(), 0, false),
			llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false),
		}
		gep := c.builder.CreateInBoundsGEP(contextAlloc, indices, "")
		c.builder.CreateStore(boundVar, gep)
	}

	// Get the context as an *i8. For the heap case this is the original
	// allocation pointer, which already is an *i8.
	context := contextHeapAlloc
	if fitsInPointer {
		// Load the packed value (as *i8) from the pointer-sized slot.
		context = c.builder.CreateLoad(packedSlot, "")
	}

	// Create the closure.
	return c.createFuncValue(f.LLVMFn, context, f.Signature)
}