Skip to content

Commit

Permalink
[vulkan][android][test_app] Add test_app variant that runs module on …
Browse files Browse the repository at this point in the history
…Vulkan (#44897)

Summary: Pull Request resolved: #44897

Test Plan: Imported from OSS

Reviewed By: dreiss

Differential Revision: D23763770

Pulled By: IvanKobzarev

fbshipit-source-id: 6ad16b7271c745313a71da64a629a764258bbc85
  • Loading branch information
IvanKobzarev authored and facebook-github-bot committed Sep 29, 2020
1 parent 2c300fd commit 17be7c6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
11 changes: 10 additions & 1 deletion android/test_app/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ android {
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
Expand All @@ -66,9 +67,17 @@ android {
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
}
mbvulkan {
dimension "model"
applicationIdSuffix ".mbvulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
}
resnet18 {
dimension "model"
applicationIdSuffix ".resneti18"
applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
Expand Down Expand Up @@ -126,7 +127,9 @@ protected Result doModuleForward() {
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
PyTorchAndroid.setNumThreads(1);
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
mModule = BuildConfig.USE_VULKAN_DEVICE
? PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
}

final long startTime = SystemClock.elapsedRealtime();
Expand Down

0 comments on commit 17be7c6

Please sign in to comment.