-
Notifications
You must be signed in to change notification settings - Fork 67
/
MNNClassification.java
118 lines (105 loc) · 3.79 KB
/
MNNClassification.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package com.yeyupiaoling.mnnclassification.mnn;
import android.content.Context;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.util.Log;
import org.json.JSONException;
import org.json.JSONObject;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
 * Image classifier backed by an MNN model running on the CPU.
 *
 * <p>The model file is loaded once in the constructor, which also creates a session. Each
 * prediction scales the bitmap to 224x224 RGB, writes it into the session's input tensor,
 * runs the session and reports the highest-scoring class. Call {@link #release()} when the
 * classifier is no longer needed.
 *
 * <p>NOTE(review): the transform matrix and input tensor are shared per instance and nothing
 * here synchronizes, so one instance should not be used from several threads at once.
 */
public class MNNClassification {
    private static final String TAG = MNNClassification.class.getName();
    /** Model input width in pixels; bitmaps are scaled to this width before inference. */
    private static final int INPUT_WIDTH = 224;
    /** Model input height in pixels; bitmaps are scaled to this height before inference. */
    private static final int INPUT_HEIGHT = 224;
    /** Thread count passed to the MNN session config. */
    private static final int NUM_THREADS = 4;

    private MNNNetInstance mNetInstance;
    private MNNNetInstance.Session mSession;
    private MNNNetInstance.Session.Tensor mInputTensor;
    private final MNNImageProcess.Config dataConfig;
    // Reused by every prediction to avoid allocating a Matrix per call.
    private final Matrix imgData = new Matrix();

    /**
     * Loads the model and creates a CPU inference session.
     *
     * @param modelPath model path
     * @throws Exception if the model file does not exist or the model/session cannot be created
     *     (the underlying failure is attached as the cause)
     */
    public MNNClassification(String modelPath) throws Exception {
        dataConfig = new MNNImageProcess.Config();
        // Presumably MNN computes (pixel - mean) * normal per channel, mapping [0, 255] to
        // roughly [-1, 1) with these values (0.0078125 == 1/128) — confirm against the model.
        dataConfig.mean = new float[]{128.0f, 128.0f, 128.0f};
        dataConfig.normal = new float[]{0.0078125f, 0.0078125f, 0.0078125f};
        dataConfig.dest = MNNImageProcess.Format.RGB;

        if (modelPath == null || !new File(modelPath).exists()) {
            throw new Exception("model file is not exists!");
        }
        try {
            mNetInstance = MNNNetInstance.createFromFile(modelPath);
            MNNNetInstance.Config config = new MNNNetInstance.Config();
            config.numThread = NUM_THREADS;
            config.forwardType = MNNForwardType.FORWARD_CPU.type;
            mSession = mNetInstance.createSession(config);
            // NOTE(review): null presumably selects the model's default input — TODO confirm.
            mInputTensor = mSession.getInput(null);
        } catch (Exception e) {
            Log.e(TAG, "load model fail!", e);
            // A failure after createFromFile succeeded would otherwise leak the net instance.
            if (mNetInstance != null) {
                mNetInstance.release();
                mNetInstance = null;
            }
            mSession = null;
            mInputTensor = null;
            throw new Exception("load model fail!", e);
        }
    }

    /**
     * Decodes the image at {@code image_path} and classifies it.
     *
     * <p>The bitmap decoded here is recycled before returning, even if inference fails.
     *
     * @param image_path path to an image file that {@link BitmapFactory} can decode
     * @return a two-element array {@code {labelIndex, score}} for the best-scoring class
     * @throws Exception if the file does not exist, cannot be decoded, or inference fails
     */
    public float[] predictImage(String image_path) throws Exception {
        if (!new File(image_path).exists()) {
            throw new Exception("image file is not exists!");
        }
        // decodeFile opens and closes the file itself, so no stream is left open.
        Bitmap bitmap = BitmapFactory.decodeFile(image_path);
        if (bitmap == null) {
            throw new Exception("image decode fail!");
        }
        try {
            return predictImage(bitmap);
        } finally {
            // This bitmap was created here, so it is ours to free.
            if (!bitmap.isRecycled()) {
                bitmap.recycle();
            }
        }
    }

    /**
     * Classifies an already-decoded bitmap. The caller keeps ownership of {@code bitmap}.
     *
     * @param bitmap image to classify; any size, it is scaled to the model input size
     * @return a two-element array {@code {labelIndex, score}} for the best-scoring class
     * @throws Exception if inference fails or the classifier has been released
     */
    public float[] predictImage(Bitmap bitmap) throws Exception {
        return predict(bitmap);
    }

    /**
     * Returns the index of the largest value in {@code result}.
     *
     * <p>Ties resolve to the first occurrence, NaN entries are skipped, and an empty array
     * yields 0. Negative scores (e.g. raw logits) are compared correctly.
     *
     * @param result per-class scores
     * @return index of the maximum score
     */
    public static int getMaxResult(float[] result) {
        // Seed with -Infinity rather than 0 so all-negative scores still pick the true maximum.
        float best = Float.NEGATIVE_INFINITY;
        int bestIndex = 0;
        for (int i = 0; i < result.length; i++) {
            if (result[i] > best) {
                best = result[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Runs one inference pass on {@code bmp}.
     *
     * @return {@code {labelIndex, score}} where labelIndex is stored as a float
     * @throws IllegalArgumentException if {@code bmp} is null
     * @throws IllegalStateException if {@link #release()} has already been called
     * @throws Exception if the session fails to run or produces no output
     */
    private float[] predict(Bitmap bmp) throws Exception {
        if (bmp == null) {
            throw new IllegalArgumentException("bitmap is null");
        }
        if (mSession == null || mInputTensor == null) {
            throw new IllegalStateException("model has been released!");
        }
        // Build the bitmap -> 224x224 scale, then invert it in place so the matrix maps
        // input-tensor coordinates back to bitmap coordinates; presumably the direction
        // convertBitmap expects — TODO confirm against MNNImageProcess.
        imgData.reset();
        imgData.postScale(INPUT_WIDTH / (float) bmp.getWidth(), INPUT_HEIGHT / (float) bmp.getHeight());
        imgData.invert(imgData);
        MNNImageProcess.convertBitmap(bmp, mInputTensor, dataConfig, imgData);
        try {
            mSession.run();
        } catch (Exception e) {
            throw new Exception("predict image fail! log:" + e, e);
        }
        MNNNetInstance.Session.Tensor output = mSession.getOutput(null);
        float[] result = output.getFloatData();
        Log.d(TAG, Arrays.toString(result));
        if (result == null || result.length == 0) {
            throw new Exception("predict image fail! empty output");
        }
        int label = getMaxResult(result);
        return new float[]{label, result[label]};
    }

    /**
     * Releases the net instance and drops the session and input tensor created from it.
     * Subsequent predictions throw {@link IllegalStateException}. Safe to call more than once.
     */
    public void release() {
        if (mNetInstance != null) {
            mNetInstance.release();
            mNetInstance = null;
        }
        // The session and tensor were created from the released net; clear them so
        // predict fails fast instead of touching a released session.
        mSession = null;
        mInputTensor = null;
    }
}