From 20fd66bd0c8d68c45fadbdb6d349260e2b82d243 Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Thu, 3 Oct 2019 21:40:31 -0700 Subject: [PATCH] [demo] Reuse input tensor buffer --- .../vision/ImageClassificationActivity.java | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java b/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java index 5c9367f2..da9b9eac 100644 --- a/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java +++ b/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java @@ -19,6 +19,7 @@ import org.pytorch.torchvision.TensorImageUtils; import java.io.File; +import java.nio.FloatBuffer; import java.util.Locale; import androidx.annotation.Nullable; @@ -59,6 +60,8 @@ public AnalysisResult(String[] topNClassNames, float[] topNScores, private TextView mMsText; private Module mModule; private String mModuleAssetName; + private FloatBuffer mInputTensorBuffer; + private Tensor mInputTensor; @Override protected int getContentViewLayoutId() { @@ -136,18 +139,24 @@ protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) { final String moduleFileAbsoluteFilePath = new File( Utils.assetFilePath(this, getModuleAssetName())).getAbsolutePath(); mModule = Module.load(moduleFileAbsoluteFilePath); + + mInputTensorBuffer = + Tensor.allocateFloatBuffer(3 * INPUT_TENSOR_WIDTH * INPUT_TENSOR_HEIGHT); + mInputTensor = Tensor.newFloat32Tensor( + new long[]{1, 3, INPUT_TENSOR_HEIGHT, INPUT_TENSOR_WIDTH}, + mInputTensorBuffer); } final long startTime = SystemClock.elapsedRealtime(); - final Tensor inputTensor = - TensorImageUtils.imageYUV420CenterCropToFloat32Tensor( - image.getImage(), rotationDegrees, - INPUT_TENSOR_WIDTH, INPUT_TENSOR_HEIGHT, - TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, - TensorImageUtils.TORCHVISION_NORM_STD_RGB); + TensorImageUtils.imageYUV420CenterCropToFloatBuffer( + image.getImage(), rotationDegrees, + INPUT_TENSOR_WIDTH, INPUT_TENSOR_HEIGHT, + TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, + TensorImageUtils.TORCHVISION_NORM_STD_RGB, + mInputTensorBuffer, 0); final long moduleForwardStartTime = SystemClock.elapsedRealtime(); - final Tensor outputTensor = mModule.forward(IValue.tensor(inputTensor)).getTensor(); + final Tensor outputTensor = mModule.forward(IValue.tensor(mInputTensor)).getTensor(); final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime; final float[] scores = outputTensor.getDataAsFloatArray();