-
Notifications
You must be signed in to change notification settings - Fork 124
/
ImageSegmenter.java
447 lines (408 loc) · 18.8 KB
/
ImageSegmenter.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.task.vision.segmenter;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.os.ParcelFileDescriptor;
import com.google.android.odml.image.MlImage;
import com.google.auto.value.AutoValue;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.lite.support.image.MlImageAdapter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
/**
* Performs segmentation on images.
*
 * <p>The API expects a TFLite model with <a
 * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
*
* <p>The API supports models with one image input tensor and one output tensor. To be more
* specific, here are the requirements.
*
* <ul>
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
* <ul>
* <li>image input of size {@code [batch x height x width x channels]}.
* <li>batch inference is not supported ({@code batch} is required to be 1).
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
* <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
* to the metadata for input normalization.
* </ul>
* <li>Output image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
* <ul>
* <li>tensor of size {@code [batch x mask_height x mask_width x num_classes]}, where {@code
* batch} is required to be 1, {@code mask_width} and {@code mask_height} are the
* dimensions of the segmentation masks produced by the model, and {@code num_classes}
* is the number of classes supported by the model.
* <li>optional (but recommended) label map(s) can be attached as AssociatedFile-s with type
* TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
* any) is used to fill the class name, i.e. {@link ColoredLabel#getlabel} of the
* results. The display name, i.e. {@link ColoredLabel#getDisplayName}, is filled from
* the AssociatedFile (if any) whose locale matches the `display_names_locale` field of
* the `ImageSegmenterOptions` used at creation time ("en" by default, i.e. English). If
* none of these are available, only the `index` field of the results will be filled.
* </ul>
* </ul>
*
 * <p>An example of such model can be found on <a
 * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub</a>.
*/
public final class ImageSegmenter extends BaseVisionTaskApi {

  private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
  // Sentinel values telling the native layer that the file descriptor length/offset are
  // unused, i.e. the descriptor refers to a whole standalone model file rather than a
  // region inside an asset package.
  private static final int OPTIONAL_FD_LENGTH = -1;
  private static final int OPTIONAL_FD_OFFSET = -1;

  // Output type configured at creation time; used to interpret the raw mask buffers
  // returned by the native layer.
  private final OutputType outputType;

  /**
   * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
   *
   * @param modelPath path of the segmentation model with metadata in the assets
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFile(Context context, String modelPath)
      throws IOException {
    return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
   *
   * @param modelFile the segmentation model {@link File} instance
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFile(File modelFile) throws IOException {
    return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
   * ImageSegmenterOptions}.
   *
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
   *     segmentation model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
   *     {@link MappedByteBuffer}
   */
  public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
    return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
   *
   * @param modelPath path of the segmentation model with metadata in the assets
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFileAndOptions(
      Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
    // The asset FD points into the APK, so the offset/length of the model region must be
    // forwarded to the native layer. try-with-resources guarantees the descriptor is closed.
    try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
      return createFromModelFdAndOptions(
          /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
          /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
          /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
          options);
    }
  }

  /**
   * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
   *
   * @param modelFile the segmentation model {@link File} instance
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFileAndOptions(
      File modelFile, final ImageSegmenterOptions options) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      return createFromModelFdAndOptions(
          /*fileDescriptor=*/ descriptor.getFd(),
          // A standalone file is read in full; length/offset are not applicable.
          /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
          /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
          options);
    }
  }

  /**
   * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
   * ImageSegmenterOptions}.
   *
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
   *     segmentation model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
   *     {@link MappedByteBuffer}
   */
  public static ImageSegmenter createFromBufferAndOptions(
      final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
    // The native layer reads the buffer directly, so only direct or mapped buffers are valid.
    if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
      throw new IllegalArgumentException(
          "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
    }
    return new ImageSegmenter(
        TaskJniUtils.createHandleFromLibrary(
            new EmptyHandleProvider() {
              @Override
              public long createHandle() {
                return initJniWithByteBuffer(
                    modelBuffer,
                    options.getDisplayNamesLocale(),
                    options.getOutputType().getValue(),
                    options.getNumThreads());
              }
            },
            IMAGE_SEGMENTER_NATIVE_LIB),
        options.getOutputType());
  }

  /**
   * Constructor to initialize the JNI with a pointer from C++.
   *
   * @param nativeHandle a pointer referencing memory allocated in C++
   */
  private ImageSegmenter(long nativeHandle, OutputType outputType) {
    super(nativeHandle);
    this.outputType = outputType;
  }

  /** Options for setting up an {@link ImageSegmenter}. */
  @AutoValue
  public abstract static class ImageSegmenterOptions {
    private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
    private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
    // -1 lets the TFLite runtime choose the number of threads.
    private static final int NUM_THREADS = -1;

    public abstract String getDisplayNamesLocale();

    public abstract OutputType getOutputType();

    public abstract int getNumThreads();

    public static Builder builder() {
      return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
          .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
          .setOutputType(DEFAULT_OUTPUT_TYPE)
          .setNumThreads(NUM_THREADS);
    }

    /** Builder for {@link ImageSegmenterOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /**
       * Sets the locale to use for display names specified through the TFLite Model Metadata, if
       * any.
       *
       * <p>Defaults to English({@code "en"}). See the <a
       * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
       * Metadata schema file.</a> for the accepted pattern of locale.
       */
      public abstract Builder setDisplayNamesLocale(String displayNamesLocale);

      public abstract Builder setOutputType(OutputType outputType);

      /**
       * Sets the number of threads to be used for TFLite ops that support multi-threading when
       * running inference with CPU. Defaults to -1.
       *
       * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
       * effect to let TFLite runtime set the value.
       */
      public abstract Builder setNumThreads(int numThreads);

      public abstract ImageSegmenterOptions build();
    }
  }

  /**
   * Performs actual segmentation on the provided image.
   *
   * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
   *
   * <ul>
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
   * </ul>
   *
   * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
   * @return results of performing image segmentation. Note that at the time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one segmentation
   *     per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   * @throws IllegalArgumentException if the color space type of image is unsupported
   */
  public List<Segmentation> segment(TensorImage image) {
    return segment(image, ImageProcessingOptions.builder().build());
  }

  /**
   * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
   *
   * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
   *
   * <ul>
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
   *   <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
   * </ul>
   *
   * <p>{@link ImageSegmenter} supports the following options:
   *
   * <ul>
   *   <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
   *       defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}
   * </ul>
   *
   * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
   * @param options the options configure how to preprocess the image
   * @return results of performing image segmentation. Note that at the time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one segmentation
   *     per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   * @throws IllegalArgumentException if the color space type of image is unsupported
   */
  public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
    return run(
        new InferenceProvider<List<Segmentation>>() {
          @Override
          public List<Segmentation> run(
              long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
            return segment(frameBufferHandle, options);
          }
        },
        image,
        options);
  }

  /**
   * Performs actual segmentation on the provided {@code MlImage}.
   *
   * @param image an {@code MlImage} to segment.
   * @return results of performing image segmentation. Note that at the time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one segmentation
   *     per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   * @throws IllegalArgumentException if the storage type or format of the image is unsupported
   */
  public List<Segmentation> segment(MlImage image) {
    return segment(image, ImageProcessingOptions.builder().build());
  }

  /**
   * Performs actual segmentation on the provided {@code MlImage} with {@link
   * ImageProcessingOptions}.
   *
   * <p>{@link ImageSegmenter} supports the following options:
   *
   * <ul>
   *   <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
   *       defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
   *       MlImage#getRotation()} is not effective.
   * </ul>
   *
   * @param image an {@code MlImage} to segment.
   * @param options the options configure how to preprocess the image.
   * @return results of performing image segmentation. Note that at the time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one segmentation
   *     per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   * @throws IllegalArgumentException if the color space type of image is unsupported
   */
  public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) {
    image.getInternal().acquire();
    try {
      TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
      return segment(tensorImage, options);
    } finally {
      // Always release the acquired reference, even if adaptation or segmentation throws;
      // otherwise the underlying image buffer leaks.
      image.close();
    }
  }

  public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) {
    checkNotClosed();
    // Output containers populated by the native layer.
    List<byte[]> maskByteArrays = new ArrayList<>();
    List<ColoredLabel> coloredLabels = new ArrayList<>();
    int[] maskShape = new int[2];
    segmentNative(getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels);
    List<ByteBuffer> maskByteBuffers = new ArrayList<>();
    for (byte[] bytes : maskByteArrays) {
      ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
      // Change the byte order to little_endian, since the buffers were generated in jni.
      byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
      maskByteBuffers.add(byteBuffer);
    }
    return Arrays.asList(
        Segmentation.create(
            outputType,
            outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
            coloredLabels));
  }

  private static ImageSegmenter createFromModelFdAndOptions(
      final int fileDescriptor,
      final long fileDescriptorLength,
      final long fileDescriptorOffset,
      final ImageSegmenterOptions options) {
    long nativeHandle =
        TaskJniUtils.createHandleFromLibrary(
            new EmptyHandleProvider() {
              @Override
              public long createHandle() {
                return initJniWithModelFdAndOptions(
                    fileDescriptor,
                    fileDescriptorLength,
                    fileDescriptorOffset,
                    options.getDisplayNamesLocale(),
                    options.getOutputType().getValue(),
                    options.getNumThreads());
              }
            },
            IMAGE_SEGMENTER_NATIVE_LIB);
    return new ImageSegmenter(nativeHandle, options.getOutputType());
  }

  private static native long initJniWithModelFdAndOptions(
      int fileDescriptor,
      long fileDescriptorLength,
      long fileDescriptorOffset,
      String displayNamesLocale,
      int outputType,
      int numThreads);

  private static native long initJniWithByteBuffer(
      ByteBuffer modelBuffer, String displayNamesLocale, int outputType, int numThreads);

  /**
   * The native method to segment the image.
   *
   * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native
   * layer.
   */
  private static native void segmentNative(
      long nativeHandle,
      long frameBufferHandle,
      List<byte[]> maskByteArrays,
      int[] maskShape,
      List<ColoredLabel> coloredLabels);

  @Override
  protected void deinit(long nativeHandle) {
    deinitJni(nativeHandle);
  }

  /**
   * Native implementation to release memory pointed by the pointer.
   *
   * @param nativeHandle pointer to memory allocated
   */
  private native void deinitJni(long nativeHandle);
}