-
Notifications
You must be signed in to change notification settings - Fork 7.4k
/
ClassifierFloatMobileNet.java
70 lines (55 loc) · 2.27 KB
/
ClassifierFloatMobileNet.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.examples.classification.tflite;
import android.app.Activity;
import java.io.IOException;
import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
/** This TensorFlowLite classifier works with the float MobileNet model. */
public class ClassifierFloatMobileNet extends Classifier {

  /**
   * Mean used to normalize the model input. {@link NormalizeOp} computes {@code (x - mean) / std},
   * so with mean = std = 127.5 input values in [0, 255] are mapped to [-1, 1]. (Assumes the input
   * image tensor holds 0-255 pixel values — confirm against {@link Classifier}.)
   */
  private static final float IMAGE_MEAN = 127.5f;

  /** Standard deviation used to normalize the model input; see {@link #IMAGE_MEAN}. */
  private static final float IMAGE_STD = 127.5f;

  /**
   * Float model does not need dequantization in the post-processing. Setting mean and std to 0.0f
   * and 1.0f, respectively, makes the post-processing {@code (x - 0) / 1} an identity operation,
   * i.e. it bypasses the normalization.
   */
  private static final float PROBABILITY_MEAN = 0.0f;

  /** Standard deviation for output post-processing; see {@link #PROBABILITY_MEAN}. */
  private static final float PROBABILITY_STD = 1.0f;

  /**
   * Initializes a {@code ClassifierFloatMobileNet} by delegating to the {@link Classifier}
   * constructor.
   *
   * @param activity the activity forwarded to the base {@link Classifier} constructor
   * @param device the {@link Device} forwarded to the base {@link Classifier} constructor
   * @param numThreads the thread count forwarded to the base {@link Classifier} constructor
   * @throws IOException if the base {@link Classifier} constructor throws it
   */
  public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads)
      throws IOException {
    super(activity, device, numThreads);
  }

  /**
   * Returns the path of the float MobileNet model file used by this classifier.
   *
   * @return {@code "model.tflite"}
   */
  @Override
  protected String getModelPath() {
    return "model.tflite";
  }

  /**
   * Returns the path of the label file used by this classifier.
   *
   * @return {@code "labels.txt"}
   */
  @Override
  protected String getLabelPath() {
    return "labels.txt";
  }

  /**
   * Returns the operator applied to the model input before inference.
   *
   * @return a {@link NormalizeOp} using {@link #IMAGE_MEAN} and {@link #IMAGE_STD}
   */
  @Override
  protected TensorOperator getPreprocessNormalizeOp() {
    return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
  }

  /**
   * Returns the operator applied to the model output after inference. For this float model it is
   * an identity normalization (mean 0, std 1).
   *
   * @return a {@link NormalizeOp} using {@link #PROBABILITY_MEAN} and {@link #PROBABILITY_STD}
   */
  @Override
  protected TensorOperator getPostprocessNormalizeOp() {
    return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD);
  }
}