Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.nio.FloatBuffer;
import java.util.Locale;

import androidx.annotation.Nullable;
Expand Down Expand Up @@ -59,6 +60,8 @@ public AnalysisResult(String[] topNClassNames, float[] topNScores,
private TextView mMsText;
private Module mModule;
private String mModuleAssetName;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;

@Override
protected int getContentViewLayoutId() {
Expand Down Expand Up @@ -136,18 +139,24 @@ protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {
final String moduleFileAbsoluteFilePath = new File(
Utils.assetFilePath(this, getModuleAssetName())).getAbsolutePath();
mModule = Module.load(moduleFileAbsoluteFilePath);

mInputTensorBuffer =
Tensor.allocateFloatBuffer(3 * INPUT_TENSOR_WIDTH * INPUT_TENSOR_HEIGHT);
mInputTensor = Tensor.newFloat32Tensor(
new long[]{1, 3, INPUT_TENSOR_HEIGHT, INPUT_TENSOR_WIDTH},
mInputTensorBuffer);
}

final long startTime = SystemClock.elapsedRealtime();
final Tensor inputTensor =
TensorImageUtils.imageYUV420CenterCropToFloat32Tensor(
image.getImage(), rotationDegrees,
INPUT_TENSOR_WIDTH, INPUT_TENSOR_HEIGHT,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
image.getImage(), rotationDegrees,
INPUT_TENSOR_WIDTH, INPUT_TENSOR_HEIGHT,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
mInputTensorBuffer, 0);

final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.tensor(inputTensor)).getTensor();
final Tensor outputTensor = mModule.forward(IValue.tensor(mInputTensor)).getTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;

final float[] scores = outputTensor.getDataAsFloatArray();
Expand Down