Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
Expand All @@ -32,7 +31,8 @@
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;

import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
Expand All @@ -43,14 +43,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;

import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;

public class CameraActivity extends AppCompatActivity {
Expand Down Expand Up @@ -143,9 +137,9 @@ public void onRequestPermissionsResult(
if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) {
if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
Toast.makeText(
this,
"You can't use image classification example without granting CAMERA permission",
Toast.LENGTH_LONG)
this,
"You can't use image classification example without granting CAMERA permission",
Toast.LENGTH_LONG)
.show();
finish();
} else {
Expand Down Expand Up @@ -217,9 +211,7 @@ private static int clamp0255(int x) {
}

protected void fillInputTensorBuffer(
ImageProxy image,
int rotationDegrees,
FloatBuffer inputTensorBuffer) {
ImageProxy image, int rotationDegrees, FloatBuffer inputTensorBuffer) {

if (mInputTensorBitmap == null) {
final int tensorSize = Math.min(image.getWidth(), image.getHeight());
Expand Down Expand Up @@ -305,7 +297,8 @@ protected void fillInputTensorBuffer(
gi = gi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (gi < 0 ? 0 : gi);
bi = bi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (bi < 0 ? 0 : bi);

final int color = 0xff000000 | ((ri << 6) & 0xff0000) | ((gi >> 2) & 0xff00) | ((bi >> 10) & 0xff);
final int color =
0xff000000 | ((ri << 6) & 0xff0000) | ((gi >> 2) & 0xff00) | ((bi >> 10) & 0xff);
mInputTensorBitmap.setPixel(x, y, color);
inputTensorBuffer.put(0 * channelSize + y * tensorSize + x, clamp0255(ri >> 10) / 255.f);
inputTensorBuffer.put(1 * channelSize + y * tensorSize + x, clamp0255(gi >> 10) / 255.f);
Expand Down Expand Up @@ -344,7 +337,7 @@ protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
if (mModule == null) {
Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * tensorSize * tensorSize);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{3, tensorSize, tensorSize});
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[] {3, tensorSize, tensorSize});
final String modelFileAbsoluteFilePath =
new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath();
mModule = Module.load(modelFileAbsoluteFilePath);
Expand All @@ -358,8 +351,8 @@ protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
final IValue out1 = outputTuple.toTuple()[1];
final Map<String, IValue> map = out1.toList()[0].toDictStringKey();

float[] boxesData = new float[]{};
float[] scoresData = new float[]{};
float[] boxesData = new float[] {};
float[] scoresData = new float[] {};
final List<BBox> bboxes = new ArrayList<>();
if (map.containsKey("boxes")) {
final Tensor boxesTensor = map.get("boxes").toTensor();
Expand All @@ -368,13 +361,13 @@ protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
scoresData = scoresTensor.getDataAsFloatArray();
final int n = scoresData.length;
for (int i = 0; i < n; i++) {
final BBox bbox = new BBox(
scoresData[i],
boxesData[4 * i + 0],
boxesData[4 * i + 1],
boxesData[4 * i + 2],
boxesData[4 * i + 3]
);
final BBox bbox =
new BBox(
scoresData[i],
boxesData[4 * i + 0],
boxesData[4 * i + 1],
boxesData[4 * i + 2],
boxesData[4 * i + 3]);
android.util.Log.i(TAG, String.format("Forward result %d: %s", i, bbox));
bboxes.add(bbox);
}
Expand Down Expand Up @@ -407,8 +400,7 @@ protected void handleResult(Result result) {
mInputTensorBitmap,
new Rect(0, 0, result.tensorSize, result.tensorSize),
new Rect(offsetX, offsetY, offsetX + size, offsetY + size),
null
);
null);

for (final BBox bbox : result.bboxes) {
if (bbox.score < BBOX_SCORE_DRAW_THRESHOLD) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
import androidx.appcompat.app.AppCompatActivity;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.nio.FloatBuffer;
import java.util.Map;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;

import java.nio.FloatBuffer;
import java.util.Map;

public class MainActivity extends AppCompatActivity {
static {
if (!NativeLoader.isInitialized()) {
Expand Down Expand Up @@ -116,8 +115,10 @@ protected Result doModuleForward() {
final float[] scoresData = scores.getDataAsFloatArray();
final int n = scoresData.length;
for (int i = 0; i < n; i++) {
android.util.Log.i(TAG,
String.format("Forward result %d: score %f box:(%f, %f, %f, %f)",
android.util.Log.i(
TAG,
String.format(
"Forward result %d: score %f box:(%f, %f, %f, %f)",
scoresData[i],
boxesData[4 * i + 0],
boxesData[4 * i + 1],
Expand All @@ -130,7 +131,7 @@ protected Result doModuleForward() {

final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(new float[]{}, moduleForwardDuration, analysisDuration);
return new Result(new float[] {}, moduleForwardDuration, analysisDuration);
}

static class Result {
Expand Down