Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TFLite UnityPlugin: Add more interpreter functions on Unity Plugin #34450

Closed
wants to merge 12 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,21 @@ public class HelloTFLite : MonoBehaviour {
}

// Unity lifecycle hook: logs the TFLite library version, creates the
// interpreter with custom options, and dumps every input/output tensor's info.
void Start () {
  // Fixed typo in log message: "Verion" -> "Version".
  Debug.LogFormat("TensorFlow Lite Version: {0}", Interpreter.GetVersion());

  var options = new Interpreter.Options() {
    threads = 2,
  };
  interpreter = new Interpreter(model.bytes, options);

  int inputCount = interpreter.GetInputTensorCount();
  int outputCount = interpreter.GetOutputTensorCount();
  for (int i = 0; i < inputCount; i++) {
    Debug.LogFormat("Input {0}: {1}", i, interpreter.GetInputTensorInfo(i));
  }
  // BUG FIX: the original loop bound was inputCount, which would skip output
  // tensors (or over-index) whenever the model's output count differs.
  for (int i = 0; i < outputCount; i++) {
    Debug.LogFormat("Output {0}: {1}", i, interpreter.GetOutputTensorInfo(i));
  }
}

void Update () {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
==============================================================================*/
using System;
using System.Runtime.InteropServices;
using System.Linq;

using TfLiteInterpreter = System.IntPtr;
using TfLiteInterpreterOptions = System.IntPtr;
Expand All @@ -27,29 +28,60 @@ namespace TensorFlowLite
/// </summary>
public class Interpreter : IDisposable
{
private const string TensorFlowLibrary = "tensorflowlite_c";
/// <summary>
/// Interpreter configuration supplied at construction time.
/// Value-equality is used to detect "all defaults" (no native options needed).
/// </summary>
public struct Options: IEquatable<Options> {
  /// <summary>
  /// The number of CPU threads to use for the interpreter.
  /// </summary>
  public int threads;

  /// <summary>Two Options are equal iff every field matches.</summary>
  public bool Equals(Options other) {
    return this.threads == other.threads;
  }
}

/// <summary>
/// Snapshot of a native tensor's metadata (name, element type, shape,
/// and quantization parameters), read via the TFLite C API.
/// </summary>
public struct TensorInfo {
  public string name { get; internal set; }
  public DataType type { get; internal set; }
  public int[] dimensions { get; internal set; }
  public QuantizationParams quantizationParams { get; internal set; }

  /// <summary>Human-readable one-line description, e.g. for Debug logs.</summary>
  public override string ToString() {
    string[] dimStrings = dimensions.Select(d => d.ToString()).ToArray();
    string dims = "[" + string.Join(",", dimStrings) + "]";
    string quant = "{" + quantizationParams + "}";
    return string.Format("name: {0}, type: {1}, dimensions: {2}, quantizationParams: {3}",
        name, type, dims, quant);
  }
}

// Native handles; IntPtr.Zero means "not created / already released".
private TfLiteModel model = IntPtr.Zero;
private TfLiteInterpreter interpreter = IntPtr.Zero;
private TfLiteInterpreterOptions options = IntPtr.Zero;

/// <summary>Creates an interpreter with default options.</summary>
public Interpreter(byte[] modelData): this(modelData, default(Options)) {}

/// <summary>
/// Creates an interpreter from a serialized flatbuffer model.
/// </summary>
/// <param name="modelData">Serialized TFLite model bytes.</param>
/// <param name="options">Configuration; pass default(Options) for native defaults.</param>
/// <exception cref="Exception">If native model or interpreter creation fails.</exception>
public Interpreter(byte[] modelData, Options options) {
  // NOTE(review): this pinned handle is never freed. The native model may keep
  // referencing the managed buffer, so it must stay pinned for the model's
  // lifetime — consider storing the handle and freeing it in Dispose.
  // TODO confirm against the TFLite C API buffer-ownership contract.
  GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
  IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
  model = TfLiteModelCreate(modelDataPtr, modelData.Length);
  if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model");

  // Only allocate a native options object when the caller customized anything;
  // otherwise pass IntPtr.Zero so the native library uses its defaults.
  if (!options.Equals(default(Options))) {
    this.options = TfLiteInterpreterOptionsCreate();
    TfLiteInterpreterOptionsSetNumThreads(this.options, options.threads);
  }

  interpreter = TfLiteInterpreterCreate(model, this.options);
  if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
}

// Safety net if the caller forgets to Dispose; Dispose is idempotent.
~Interpreter() {
  Dispose();
}

/// <summary>
/// Releases all native resources. Safe to call more than once.
/// Deletion order matters: the interpreter references the model, so the
/// interpreter is deleted first, then the model, then the options.
/// </summary>
public void Dispose() {
  if (interpreter != IntPtr.Zero) TfLiteInterpreterDelete(interpreter);
  interpreter = IntPtr.Zero;
  if (model != IntPtr.Zero) TfLiteModelDelete(model);
  model = IntPtr.Zero;
  if (options != IntPtr.Zero) TfLiteInterpreterOptionsDelete(options);
  options = IntPtr.Zero;
  // The finalizer only calls Dispose; once disposed there is nothing left
  // for it to do, so suppress it (standard IDisposable pattern).
  GC.SuppressFinalize(this);
}

public void Invoke() {
Expand Down Expand Up @@ -89,18 +121,98 @@ public class Interpreter : IDisposable
tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData)));
}

/// <summary>Returns metadata for the input tensor at the given index.</summary>
public TensorInfo GetInputTensorInfo(int index) {
  return GetTensorInfo(TfLiteInterpreterGetInputTensor(interpreter, index));
}

/// <summary>Returns metadata for the output tensor at the given index.</summary>
public TensorInfo GetOutputTensorInfo(int index) {
  return GetTensorInfo(TfLiteInterpreterGetOutputTensor(interpreter, index));
}

/// <summary>
/// Returns a string describing version information of the TensorFlow Lite library.
/// TensorFlow Lite uses semantic versioning.
/// </summary>
/// <returns>A string describing version information</returns>
public static string GetVersion() {
  // TfLiteVersion returns a native char*; marshal it to a managed string.
  return Marshal.PtrToStringAnsi(TfLiteVersion());
}

/// <summary>Marshals the native tensor's name (char*) into a managed string.</summary>
private static string GetTensorName(TfLiteTensor tensor) {
  IntPtr namePtr = TfLiteTensorName(tensor);
  return Marshal.PtrToStringAnsi(namePtr);
}

/// <summary>
/// Reads name, type, shape, and quantization parameters from a native tensor
/// handle and packages them into a managed TensorInfo.
/// </summary>
private static TensorInfo GetTensorInfo(TfLiteTensor tensor) {
  int rank = TfLiteTensorNumDims(tensor);
  var dims = new int[rank];
  for (int axis = 0; axis < rank; axis++) {
    dims[axis] = TfLiteTensorDim(tensor, axis);
  }
  return new TensorInfo() {
    name = GetTensorName(tensor),
    type = TfLiteTensorType(tensor),
    dimensions = dims,
    quantizationParams = TfLiteTensorQuantizationParams(tensor),
  };
}

/// <summary>
/// Translates a native TfLiteStatus result code into a managed exception;
/// 0 (kTfLiteOk) is the only success value.
/// </summary>
private static void ThrowIfError(int resultCode) {
  if (resultCode == 0) return;
  throw new Exception("TensorFlowLite operation failed.");
}

#region Externs

// On iOS the native TFLite library is statically linked into the app binary,
// so P/Invoke must target "__Internal"; on all other platforms (and in the
// editor) it is loaded as a dynamic library named "tensorflowlite_c".
#if UNITY_IPHONE && !UNITY_EDITOR
private const string TensorFlowLibrary = "__Internal";
#else
private const string TensorFlowLibrary = "tensorflowlite_c";
#endif

// Element type of a tensor. The numeric values must stay in sync with the
// native TfLiteType enum in the TFLite C API — do not renumber or reorder.
public enum DataType {
  NoType = 0,
  Float32 = 1,
  Int32 = 2,
  UInt8 = 3,
  Int64 = 4,
  String = 5,
  Bool = 6,
  Int16 = 7,
  Complex64 = 8,
  Int8 = 9,
  Float16 = 10,
}

// Affine quantization parameters returned by value from native code
// (TfLiteTensorQuantizationParams). Field order/types must match the native
// struct layout — do not reorder or add fields.
public struct QuantizationParams {
  public float scale;
  public int zeroPoint;

  public override string ToString() {
    return string.Format("scale: {0} zeroPoint: {1}", scale, zeroPoint);
  }
}

// Returns the version string (char*) of the native TFLite library.
[DllImport (TensorFlowLibrary)]
private static extern unsafe IntPtr TfLiteVersion();

// Builds a model from a serialized flatbuffer buffer; returns null handle on failure.
// NOTE(review): unclear from here whether the native side copies the buffer —
// callers pin modelData; confirm ownership in the C API header.
[DllImport (TensorFlowLibrary)]
private static extern unsafe TfLiteInterpreter TfLiteModelCreate(IntPtr model_data, int model_size);

// Destroys a model handle created by TfLiteModelCreate.
[DllImport (TensorFlowLibrary)]
private static extern unsafe TfLiteInterpreter TfLiteModelDelete(TfLiteModel model);

// Allocates a native interpreter-options object (pair with ...OptionsDelete).
[DllImport (TensorFlowLibrary)]
private static extern unsafe TfLiteInterpreterOptions TfLiteInterpreterOptionsCreate();

// Destroys an options object created by TfLiteInterpreterOptionsCreate.
[DllImport (TensorFlowLibrary)]
private static extern unsafe void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions options);

// Sets the number of CPU threads the interpreter may use.
[DllImport (TensorFlowLibrary)]
private static extern unsafe void TfLiteInterpreterOptionsSetNumThreads(
  TfLiteInterpreterOptions options,
  int num_threads
);

[DllImport (TensorFlowLibrary)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say let's leave this out until we have it wired up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I am implementing GPU delegate. I'll add these in the next PR.
https://github.com/asus4/tf-lite-unity-sample

private static extern unsafe TfLiteInterpreter TfLiteInterpreterCreate(
TfLiteModel model,
Expand Down Expand Up @@ -140,6 +252,24 @@ public class Interpreter : IDisposable
private static extern unsafe TfLiteTensor TfLiteInterpreterGetOutputTensor(
TfLiteInterpreter interpreter,
int output_index);

// Returns the element type of the tensor (maps to the DataType enum above).
[DllImport (TensorFlowLibrary)]
private static extern unsafe DataType TfLiteTensorType(TfLiteTensor tensor);

// Returns the tensor's rank (number of dimensions).
[DllImport (TensorFlowLibrary)]
private static extern unsafe int TfLiteTensorNumDims(TfLiteTensor tensor);

// Returns the size of dimension dim_index (0 <= dim_index < NumDims).
[DllImport (TensorFlowLibrary)]
private static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index);

// Returns the total size of the tensor's data buffer in bytes.
[DllImport (TensorFlowLibrary)]
private static extern uint TfLiteTensorByteSize(TfLiteTensor tensor);

// Returns the tensor's name as a native char* (marshal with PtrToStringAnsi).
[DllImport (TensorFlowLibrary)]
private static extern unsafe IntPtr TfLiteTensorName(TfLiteTensor tensor);

// Returns the tensor's quantization parameters by value; struct layout must
// match the native TfLiteQuantizationParams.
[DllImport (TensorFlowLibrary)]
private static extern unsafe QuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor);

[DllImport (TensorFlowLibrary)]
private static extern unsafe int TfLiteTensorCopyFromBuffer(
Expand Down