-
Notifications
You must be signed in to change notification settings - Fork 74k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
TFLite UnityPlugin: Add more interpreter functions on Unity Plugin #34450
Changes from all commits
5bad210
0178d57
5dfb750
4ab3b90
5ec8c1d
be721b8
0127c02
bc5aa83
cf11e95
1bd2768
9a7508c
5776740
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
==============================================================================*/ | ||
using System; | ||
using System.Runtime.InteropServices; | ||
using System.Linq; | ||
|
||
using TfLiteInterpreter = System.IntPtr; | ||
using TfLiteInterpreterOptions = System.IntPtr; | ||
|
@@ -27,29 +28,60 @@ namespace TensorFlowLite | |
/// </summary> | ||
public class Interpreter : IDisposable | ||
{ | ||
private const string TensorFlowLibrary = "tensorflowlite_c"; | ||
public struct Options: IEquatable<Options> { | ||
/// <summary> | ||
/// The number of CPU threads to use for the interpreter. | ||
/// </summary> | ||
public int threads; | ||
|
||
public bool Equals(Options other) { | ||
return threads == other.threads; | ||
} | ||
} | ||
|
||
public struct TensorInfo { | ||
public string name { get; internal set; } | ||
public DataType type { get; internal set; } | ||
public int[] dimensions { get; internal set; } | ||
public QuantizationParams quantizationParams { get; internal set; } | ||
|
||
public override string ToString() { | ||
return string.Format("name: {0}, type: {1}, dimensions: {2}, quantizationParams: {3}", | ||
name, | ||
type, | ||
"[" + string.Join(",", dimensions.Select(d => d.ToString()).ToArray()) + "]", | ||
"{" + quantizationParams + "}"); | ||
} | ||
} | ||
|
||
private TfLiteModel model = IntPtr.Zero; | ||
private TfLiteInterpreter interpreter = IntPtr.Zero; | ||
private TfLiteInterpreterOptions options = IntPtr.Zero; | ||
|
||
private TfLiteModel model; | ||
private TfLiteInterpreter interpreter; | ||
public Interpreter(byte[] modelData): this(modelData, default(Options)) {} | ||
|
||
public Interpreter(byte[] modelData) { | ||
public Interpreter(byte[] modelData, Options options) { | ||
GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned); | ||
IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject(); | ||
model = TfLiteModelCreate(modelDataPtr, modelData.Length); | ||
if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model"); | ||
interpreter = TfLiteInterpreterCreate(model, /*options=*/IntPtr.Zero); | ||
if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter"); | ||
} | ||
|
||
if (!options.Equals(default(Options))) { | ||
this.options = TfLiteInterpreterOptionsCreate(); | ||
TfLiteInterpreterOptionsSetNumThreads(this.options, options.threads); | ||
} | ||
|
||
~Interpreter() { | ||
Dispose(); | ||
interpreter = TfLiteInterpreterCreate(model, this.options); | ||
if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter"); | ||
} | ||
|
||
public void Dispose() { | ||
if (interpreter != IntPtr.Zero) TfLiteInterpreterDelete(interpreter); | ||
interpreter = IntPtr.Zero; | ||
if (model != IntPtr.Zero) TfLiteModelDelete(model); | ||
model = IntPtr.Zero; | ||
if (options != IntPtr.Zero) TfLiteInterpreterOptionsDelete(options); | ||
options = IntPtr.Zero; | ||
} | ||
|
||
public void Invoke() { | ||
|
@@ -89,18 +121,98 @@ public class Interpreter : IDisposable | |
tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData))); | ||
} | ||
|
||
public TensorInfo GetInputTensorInfo(int index) { | ||
TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, index); | ||
return GetTensorInfo(tensor); | ||
} | ||
|
||
public TensorInfo GetOutputTensorInfo(int index) { | ||
TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, index); | ||
return GetTensorInfo(tensor); | ||
} | ||
|
||
/// <summary> | ||
/// Returns a string describing version information of the TensorFlow Lite library. | ||
/// TensorFlow Lite uses semantic versioning. | ||
/// </summary> | ||
/// <returns>A string describing version information</returns> | ||
public static string GetVersion() { | ||
jdduke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Marshal.PtrToStringAnsi(TfLiteVersion()); | ||
} | ||
|
||
private static string GetTensorName(TfLiteTensor tensor) { | ||
return Marshal.PtrToStringAnsi(TfLiteTensorName(tensor)); | ||
} | ||
|
||
private static TensorInfo GetTensorInfo(TfLiteTensor tensor) { | ||
int[] dimensions = new int[TfLiteTensorNumDims(tensor)]; | ||
for (int i = 0; i < dimensions.Length; i++) { | ||
dimensions[i] = TfLiteTensorDim(tensor, i); | ||
} | ||
return new TensorInfo() { | ||
name = GetTensorName(tensor), | ||
type = TfLiteTensorType(tensor), | ||
dimensions = dimensions, | ||
quantizationParams = TfLiteTensorQuantizationParams(tensor), | ||
}; | ||
} | ||
|
||
private static void ThrowIfError(int resultCode) { | ||
if (resultCode != 0) throw new Exception("TensorFlowLite operation failed."); | ||
} | ||
|
||
#region Externs | ||
|
||
#if UNITY_IPHONE && !UNITY_EDITOR | ||
private const string TensorFlowLibrary = "__Internal"; | ||
#else | ||
private const string TensorFlowLibrary = "tensorflowlite_c"; | ||
#endif | ||
|
||
public enum DataType { | ||
NoType = 0, | ||
Float32 = 1, | ||
Int32 = 2, | ||
UInt8 = 3, | ||
Int64 = 4, | ||
String = 5, | ||
Bool = 6, | ||
Int16 = 7, | ||
Complex64 = 8, | ||
Int8 = 9, | ||
Float16 = 10, | ||
} | ||
|
||
public struct QuantizationParams { | ||
public float scale; | ||
public int zeroPoint; | ||
|
||
public override string ToString() { | ||
return string.Format("scale: {0} zeroPoint: {1}", scale, zeroPoint); | ||
} | ||
} | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe IntPtr TfLiteVersion(); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe TfLiteInterpreter TfLiteModelCreate(IntPtr model_data, int model_size); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe TfLiteInterpreter TfLiteModelDelete(TfLiteModel model); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe TfLiteInterpreterOptions TfLiteInterpreterOptionsCreate(); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions options); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe void TfLiteInterpreterOptionsSetNumThreads( | ||
TfLiteInterpreterOptions options, | ||
int num_threads | ||
); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would say let's leave this out until we have it wired up? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok, I am implementing GPU delegate. I'll add these in the next PR. |
||
private static extern unsafe TfLiteInterpreter TfLiteInterpreterCreate( | ||
TfLiteModel model, | ||
|
@@ -140,6 +252,24 @@ public class Interpreter : IDisposable | |
private static extern unsafe TfLiteTensor TfLiteInterpreterGetOutputTensor( | ||
TfLiteInterpreter interpreter, | ||
int output_index); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe DataType TfLiteTensorType(TfLiteTensor tensor); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe int TfLiteTensorNumDims(TfLiteTensor tensor); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern uint TfLiteTensorByteSize(TfLiteTensor tensor); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe IntPtr TfLiteTensorName(TfLiteTensor tensor); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe QuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor); | ||
|
||
[DllImport (TensorFlowLibrary)] | ||
private static extern unsafe int TfLiteTensorCopyFromBuffer( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For clarity, we should delete the interpreter before the model, as the interpreter references the model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed it