/
ImageSegmentationHelper.kt
163 lines (143 loc) · 5.69 KB
/
ImageSegmentationHelper.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
/*
* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.lite.examples.imagesegmentation
import android.content.Context
import android.graphics.Bitmap
import android.os.Build
import android.os.SystemClock
import android.util.Log
import androidx.annotation.RequiresApi
import androidx.core.graphics.get
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.Rot90Op
import org.tensorflow.lite.task.core.BaseOptions
import org.tensorflow.lite.task.vision.segmenter.ImageSegmenter
import org.tensorflow.lite.task.vision.segmenter.OutputType
import org.tensorflow.lite.task.vision.segmenter.Segmentation
import java.io.IOException
import java.lang.Exception
import java.util.*
/**
 * Runs the DeepLab image segmentation model through the TensorFlow Lite Task Library.
 *
 * More information about the DeepLab model being used can be found here:
 * https://ai.googleblog.com/2018/03/semantic-image-segmentation-with.html
 * https://github.com/tensorflow/models/tree/master/research/deeplab
 *
 * Label names: 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
 * 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
 * 'sofa', 'train', 'tv'
 *
 * @property numThreads number of threads passed to [BaseOptions]. Only read when the segmenter is
 *   (re)created, so a change takes effect after [clearImageSegmenter] and the next [segment] call.
 * @property currentDelegate hardware to run on: [DELEGATE_CPU], [DELEGATE_GPU] or
 *   [DELEGATE_NNAPI]. Like [numThreads], only read when the segmenter is (re)created.
 * @property context Android context used to load [MODEL_DEEPLABV3].
 * @property imageSegmentationListener receives errors and results; may be null.
 */
class ImageSegmentationHelper(
    var numThreads: Int = 2,
    var currentDelegate: Int = 0,
    val context: Context,
    val imageSegmentationListener: SegmentationListener?
) {
    private var imageSegmenter: ImageSegmenter? = null

    init {
        setupImageSegmenter()
    }

    /**
     * Releases the current segmenter. The next call to [segment] builds a new one from the current
     * [numThreads] and [currentDelegate] values.
     */
    fun clearImageSegmenter() {
        // ImageSegmenter is Closeable; close it explicitly so its resources are released now
        // rather than whenever (if ever) the garbage collector gets to it.
        imageSegmenter?.close()
        imageSegmenter = null
    }

    /**
     * Builds [imageSegmenter] from [MODEL_DEEPLABV3] using the current [numThreads] and
     * [currentDelegate]. On failure it reports through [imageSegmentationListener], logs the
     * cause, and does not assign [imageSegmenter], so [segment] retries on its next call.
     */
    private fun setupImageSegmenter() {
        // Create the base options for the segmenter
        val optionsBuilder = ImageSegmenter.ImageSegmenterOptions.builder()

        // Set general segmentation options, including number of used threads
        val baseOptionsBuilder = BaseOptions.builder().setNumThreads(numThreads)

        // Use the specified hardware for running the model. Default to CPU
        when (currentDelegate) {
            DELEGATE_CPU -> {
                // Default
            }
            DELEGATE_GPU -> {
                if (CompatibilityList().isDelegateSupportedOnThisDevice) {
                    baseOptionsBuilder.useGpu()
                } else {
                    // Report it but carry on: without useGpu() the default (CPU) is used.
                    imageSegmentationListener?.onError("GPU is not supported on this device")
                }
            }
            DELEGATE_NNAPI -> {
                baseOptionsBuilder.useNnapi()
            }
        }

        optionsBuilder.setBaseOptions(baseOptionsBuilder.build())

        /*
        CATEGORY_MASK is being specifically used to predict the available objects
        based on individual pixels in this sample. The other option available for
        OutputType, CONFIDENCE_MAP, provides a gray scale mapping of the image
        where each pixel has a confidence score applied to it from 0.0f to 1.0f
        */
        optionsBuilder.setOutputType(OutputType.CATEGORY_MASK)

        try {
            imageSegmenter = ImageSegmenter.createFromFileAndOptions(
                context,
                MODEL_DEEPLABV3,
                optionsBuilder.build()
            )
        } catch (e: IllegalStateException) {
            reportSetupFailure(e)
        } catch (e: IOException) {
            // Declared by createFromFileAndOptions when the model cannot be loaded. Kotlin does
            // not force handling it, so previously it escaped and crashed the caller.
            reportSetupFailure(e)
        }
    }

    /** Tells the listener the segmenter could not be created and logs [cause] with its trace. */
    private fun reportSetupFailure(cause: Exception) {
        imageSegmentationListener?.onError(
            "Image segmentation failed to initialize. See error logs for details"
        )
        Log.e(TAG, "TFLite failed to load model with error: ${cause.message}", cause)
    }

    /**
     * Segments [image] and delivers the outcome to [imageSegmentationListener].
     *
     * Re-creates the segmenter first if it was cleared or previously failed to initialize. If it
     * is still unavailable, `onResults` is called with `null` results.
     *
     * @param image the frame to segment.
     * @param imageRotation rotation of [image] in degrees. Divided by 90 with integer division,
     *   so presumably a multiple of 90 — confirm with callers.
     */
    @RequiresApi(Build.VERSION_CODES.Q)
    fun segment(image: Bitmap, imageRotation: Int) {
        if (imageSegmenter == null) {
            setupImageSegmenter()
        }

        // Measured from before preprocessing until the model returns, so the reported
        // "inference time" includes the rotation step below.
        val startTime = SystemClock.uptimeMillis()

        // Create preprocessor for the image.
        // See https://www.tensorflow.org/lite/inference_with_metadata/
        // lite_support#imageprocessor_architecture
        // Rot90Op(k) turns counter-clockwise k quarter turns; the negated count turns clockwise.
        val imageProcessor = ImageProcessor.Builder()
            .add(Rot90Op(-imageRotation / 90))
            .build()

        // Preprocess the image and convert it into a TensorImage for segmentation.
        val tensorImage = imageProcessor.process(TensorImage.fromBitmap(image))

        val segmentResult = imageSegmenter?.segment(tensorImage)
        val inferenceTime = SystemClock.uptimeMillis() - startTime

        // Dimensions are those of the rotated tensor image, not the original bitmap.
        imageSegmentationListener?.onResults(
            segmentResult,
            inferenceTime,
            tensorImage.height,
            tensorImage.width
        )
    }

    /** Callbacks for segmentation errors and results. */
    interface SegmentationListener {
        /** Called with a human-readable message when setup or delegate selection fails. */
        fun onError(error: String)

        /**
         * Called after each [segment] call.
         *
         * @param results segmentation output, or null if no segmenter was available.
         * @param inferenceTime elapsed milliseconds, including preprocessing.
         * @param imageHeight height of the (rotated) image that was segmented.
         * @param imageWidth width of the (rotated) image that was segmented.
         */
        fun onResults(
            results: List<Segmentation>?,
            inferenceTime: Long,
            imageHeight: Int,
            imageWidth: Int
        )
    }

    companion object {
        const val DELEGATE_CPU = 0
        const val DELEGATE_GPU = 1
        const val DELEGATE_NNAPI = 2
        const val MODEL_DEEPLABV3 = "deeplabv3.tflite"

        private const val TAG = "Image Segmentation Helper"
    }
}