forked from neo-ai/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Function.java
332 lines (303 loc) · 9.29 KB
/
Function.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tvm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * TVM Packed Function.
 *
 * <p>Wraps a native function handle. Arguments are supplied either through the
 * {@code pushArg} overloads followed by {@link #invoke()}, or all at once through
 * {@link #call(Object...)}.
 */
public class Function extends TVMValue {
  /** Native handle of the underlying function. */
  final long handle;
  /** Whether this is a resident function; {@link #release()} never frees resident handles. */
  public final boolean isResident;
  /** Set once the native handle of a non-resident function has been freed. */
  private boolean isReleased = false;

  /**
   * Get registered function.
   *
   * @param name full function name.
   * @return TVM function, or {@code null} if no global function with that name is registered.
   */
  public static Function getFunction(final String name) {
    for (String fullName : listGlobalFuncNames()) {
      if (fullName.equals(name)) {
        // Looked-up globals are created as resident, so release() will not free them.
        return getGlobalFunc(fullName, true, false);
      }
    }
    return null;
  }

  /**
   * Get list of global functions registered.
   *
   * @return unmodifiable list of global function names.
   */
  private static List<String> listGlobalFuncNames() {
    List<String> names = new ArrayList<String>();
    Base.checkCall(Base._LIB.tvmFuncListGlobalNames(names));
    return Collections.unmodifiableList(names);
  }

  /**
   * Get a global function by name.
   *
   * @param name The name of the function.
   * @param isResident Whether it is a global 'resident' function.
   * @param allowMissing Whether to return {@code null} (true) or throw (false) when the
   *     function is missing.
   * @return The function, or {@code null} if it is missing and {@code allowMissing} is true.
   * @throws IllegalArgumentException if the function is missing and {@code allowMissing} is
   *     false.
   */
  private static Function getGlobalFunc(String name, boolean isResident, boolean allowMissing) {
    Base.RefLong handle = new Base.RefLong();
    Base.checkCall(Base._LIB.tvmFuncGetGlobal(name, handle));
    if (handle.value != 0) {
      return new Function(handle.value, isResident);
    }
    // A zero handle is treated as "no function registered under this name".
    if (allowMissing) {
      return null;
    }
    throw new IllegalArgumentException("Cannot find global function " + name);
  }

  /**
   * Initialize the function with handle.
   *
   * @param handle the handle to the underlying function.
   * @param isResident Whether this is a resident function in jvm
   */
  Function(long handle, boolean isResident) {
    super(ArgTypeCode.FUNC_HANDLE);
    this.handle = handle;
    this.isResident = isResident;
  }

  /**
   * Initialize a non-resident function with handle.
   *
   * @param handle the handle to the underlying function.
   */
  Function(long handle) {
    this(handle, false);
  }

  @Override protected void finalize() throws Throwable {
    release();
    super.finalize();
  }

  /**
   * Easy for user to get the instance from returned TVMValue.
   *
   * @return this
   */
  @Override public Function asFunction() {
    return this;
  }

  @Override long asHandle() {
    return handle;
  }

  /**
   * Release the Function.
   *
   * <p>We highly recommend you to do this manually since the GC strategy is lazy.
   * Calling this more than once is safe. Resident functions are never freed.
   */
  @Override public void release() {
    if (!isReleased && !isResident) {
      Base.checkCall(Base._LIB.tvmFuncFree(handle));
      isReleased = true;
    }
  }

  /**
   * Invoke the function with the arguments pushed so far through the {@code pushArg}
   * overloads.
   *
   * @return the result.
   */
  public TVMValue invoke() {
    Base.RefTVMValue ret = new Base.RefTVMValue();
    Base.checkCall(Base._LIB.tvmFuncCall(handle, ret));
    return ret.value;
  }

  /**
   * Push argument to the function.
   *
   * @param arg int argument.
   * @return this
   */
  public Function pushArg(int arg) {
    Base._LIB.tvmFuncPushArgLong(arg);
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg long argument.
   * @return this
   */
  public Function pushArg(long arg) {
    Base._LIB.tvmFuncPushArgLong(arg);
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg float argument (widened to double).
   * @return this
   */
  public Function pushArg(float arg) {
    Base._LIB.tvmFuncPushArgDouble(arg);
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg double argument.
   * @return this
   */
  public Function pushArg(double arg) {
    Base._LIB.tvmFuncPushArgDouble(arg);
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg String argument.
   * @return this
   */
  public Function pushArg(String arg) {
    Base._LIB.tvmFuncPushArgString(arg);
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg NDArray.
   * @return this
   */
  public Function pushArg(NDArrayBase arg) {
    Base._LIB.tvmFuncPushArgHandle(arg.handle, ndArrayTypeCode(arg));
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg Module.
   * @return this
   */
  public Function pushArg(Module arg) {
    Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id);
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg Function.
   * @return this
   */
  public Function pushArg(Function arg) {
    Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id);
    return this;
  }

  /**
   * Push argument to the function.
   *
   * @param arg bytes.
   * @return this
   */
  public Function pushArg(byte[] arg) {
    Base._LIB.tvmFuncPushArgBytes(arg);
    return this;
  }

  /**
   * Invoke function with arguments.
   *
   * @param args Can be Integer, Long, Float, Double, String, byte[], NDArray, Module,
   *     Function, or a TVMValue of a supported type code.
   * @return the result.
   * @throws IllegalArgumentException if any argument has an unsupported type; in that case
   *     nothing is pushed and the function is not invoked.
   */
  public TVMValue call(Object... args) {
    // The push calls take no function handle, so pushed arguments accumulate in shared
    // native-side state until the next invoke. Validate everything first so a bad argument
    // in the middle cannot leave a partial argument list behind for a later call.
    for (Object arg : args) {
      if (!isSupportedArg(arg)) {
        throw new IllegalArgumentException("Invalid argument: " + arg);
      }
    }
    for (Object arg : args) {
      pushArgToStack(arg);
    }
    return invoke();
  }

  /** Returns the type code used to push an NDArray: view arrays vs. owning containers. */
  private static int ndArrayTypeCode(NDArrayBase nd) {
    return nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id;
  }

  /** Returns whether {@link #pushArgToStack(Object)} accepts {@code arg}; must mirror it. */
  private static boolean isSupportedArg(Object arg) {
    if (arg instanceof Integer || arg instanceof Long || arg instanceof Float
        || arg instanceof Double || arg instanceof String || arg instanceof byte[]
        || arg instanceof NDArrayBase || arg instanceof Module || arg instanceof Function) {
      return true;
    }
    if (arg instanceof TVMValue) {
      switch (((TVMValue) arg).typeCode) {
        case UINT:
        case INT:
        case FLOAT:
        case STR:
        case BYTES:
        case HANDLE:
        case ARRAY_HANDLE:
        case MODULE_HANDLE:
        case FUNC_HANDLE:
          return true;
        default:
          return false;
      }
    }
    // Covers null and every other unsupported type.
    return false;
  }

  /** Pushes a single boxed argument onto the native argument stack, dispatching on type. */
  private static void pushArgToStack(Object arg) {
    if (arg instanceof Integer) {
      Base._LIB.tvmFuncPushArgLong((Integer) arg);
    } else if (arg instanceof Long) {
      Base._LIB.tvmFuncPushArgLong((Long) arg);
    } else if (arg instanceof Float) {
      Base._LIB.tvmFuncPushArgDouble((Float) arg);
    } else if (arg instanceof Double) {
      Base._LIB.tvmFuncPushArgDouble((Double) arg);
    } else if (arg instanceof String) {
      Base._LIB.tvmFuncPushArgString((String) arg);
    } else if (arg instanceof byte[]) {
      Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
    } else if (arg instanceof NDArrayBase) {
      NDArrayBase nd = (NDArrayBase) arg;
      Base._LIB.tvmFuncPushArgHandle(nd.handle, ndArrayTypeCode(nd));
    } else if (arg instanceof Module) {
      Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id);
    } else if (arg instanceof Function) {
      Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id);
    } else if (arg instanceof TVMValue) {
      TVMValue tvmArg = (TVMValue) arg;
      switch (tvmArg.typeCode) {
        case UINT:
        case INT:
          Base._LIB.tvmFuncPushArgLong(tvmArg.asLong());
          break;
        case FLOAT:
          Base._LIB.tvmFuncPushArgDouble(tvmArg.asDouble());
          break;
        case STR:
          Base._LIB.tvmFuncPushArgString(tvmArg.asString());
          break;
        case BYTES:
          Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes());
          break;
        case HANDLE:
        case ARRAY_HANDLE:
        case MODULE_HANDLE:
        case FUNC_HANDLE:
          Base._LIB.tvmFuncPushArgHandle(tvmArg.asHandle(), tvmArg.typeCode.id);
          break;
        default:
          throw new IllegalArgumentException("Invalid argument: " + arg);
      }
    } else {
      throw new IllegalArgumentException("Invalid argument: " + arg);
    }
  }

  /**
   * A Java function that can be registered or converted into a TVM function.
   */
  public interface Callback {
    /**
     * Called when the TVM function is invoked.
     *
     * @param args the call arguments.
     * @return the result value.
     */
    Object invoke(TVMValue... args);
  }

  /**
   * Register user-defined global function.
   *
   * @param name The function name.
   * @param function The function to be registered.
   * @param override Whether override existing entry.
   */
  public static void register(String name, Callback function, boolean override) {
    Base.RefLong createdFuncHandleRef = new Base.RefLong();
    Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef));
    int ioverride = override ? 1 : 0;
    boolean registered = false;
    try {
      Base.checkCall(
          Base._LIB.tvmFuncRegisterGlobal(name, createdFuncHandleRef.value, ioverride));
      registered = true;
    } finally {
      if (!registered) {
        // Registration failed, so nothing else references the freshly created handle: free it
        // instead of leaking it. The free status is deliberately ignored so the original
        // registration error is the one that propagates.
        Base._LIB.tvmFuncFree(createdFuncHandleRef.value);
      }
    }
  }

  /**
   * Register user-defined global function, do not override existing entry.
   *
   * @param name The function name.
   * @param function The function to be registered.
   */
  public static void register(String name, Callback function) {
    register(name, function, false);
  }

  /**
   * Convert a Java function to TVM function.
   *
   * @param function Java function.
   * @return TVM function (non-resident, so {@link #release()} frees it).
   */
  public static Function convertFunc(Callback function) {
    Base.RefLong createdFuncHandleRef = new Base.RefLong();
    Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef));
    return new Function(createdFuncHandleRef.value);
  }

  // NOTE(review): nothing in this class calls this method; presumably it is the entry point
  // used by native code (JNI) to dispatch into a registered Callback — verify before renaming.
  private static Object invokeRegisteredCbFunc(Callback cb, TVMValue[] args) {
    if (cb == null) {
      System.err.println("[ERROR] Failed to get registered function");
      return null;
    }
    return cb.invoke(args);
  }
}