Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[android][vulkan] Module load argument to specify device cpu/vulkan #44896

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@

namespace pytorch_jni {

c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
if (deviceJniCode == kDeviceCPU) {
return at::kCPU;
} else if (deviceJniCode == kDeviceVulkan) {
return at::kVulkan;
}

facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
}

bool Trace::is_initialized_ = false;

#if defined(TRACE_ENABLED) && defined(__ANDROID__)
Expand Down
7 changes: 7 additions & 0 deletions android/pytorch_android/src/main/cpp/pytorch_jni_common.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#include <fbjni/fbjni.h>
#include <torch/csrc/api/include/torch/types.h>

Expand All @@ -18,6 +20,11 @@

namespace pytorch_jni {

constexpr static int kDeviceCPU = 1;
constexpr static int kDeviceVulkan = 2;

c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);

class Trace {
public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
Expand Down
36 changes: 28 additions & 8 deletions android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,25 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
torch::jit::Module module_;
c10::DeviceType deviceType_;

public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";

static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath) {
return makeCxxInstance(modelPath);
facebook::jni::alias_ref<jstring> modelPath,
jint device) {
return makeCxxInstance(modelPath, device);
}

#ifdef __ANDROID__
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager) {
return makeCxxInstance(assetName, assetManager);
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
return makeCxxInstance(assetName, assetManager, device);
}
#endif

Expand Down Expand Up @@ -127,17 +130,19 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
((void)once);
}

PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
PytorchJni(facebook::jni::alias_ref<jstring> modelPath, jint device) {
preModuleLoadSetup();
JITCallGuard guard;
module_ = torch::jit::load(std::move(modelPath->toStdString()));
module_.eval();
deviceType_ = deviceJniCodeToDeviceType(device);
}

#ifdef __ANDROID__
PytorchJni(
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager) {
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
preModuleLoadSetup();
JNIEnv* env = facebook::jni::Environment::current();
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
Expand Down Expand Up @@ -166,6 +171,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
module_.eval();
deviceType_ = deviceJniCodeToDeviceType(device);
}
#endif

Expand All @@ -191,7 +197,14 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
inputs.reserve(n);
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
if (at::kVulkan == deviceType_) {
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_);
inputs.push_back(std::move(atIValue));
}
}
auto output = [&]() {
JITCallGuard guard;
Expand All @@ -212,7 +225,14 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
inputs.reserve(n);
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
if (at::kVulkan == deviceType_) {
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_);
inputs.push_back(std::move(atIValue));
}
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
Expand Down
22 changes: 9 additions & 13 deletions android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ struct LiteJITCallGuard {
} // namespace

class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
constexpr static int kDeviceCPU = 1;
constexpr static int kDeviceVulkan = 2;

private:
friend HybridBase;
torch::jit::mobile::Module module_;
Expand All @@ -51,15 +48,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
PytorchJni(facebook::jni::alias_ref<jstring> modelPath, jint device) {
LiteJITCallGuard guard;
module_ = torch::jit::_load_for_mobile(std::move(modelPath->toStdString()));
if (device == kDeviceCPU) {
deviceType_ = at::kCPU;
} else if (device == kDeviceVulkan) {
deviceType_ = at::kVulkan;
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown device specified");
}
deviceType_ = deviceJniCodeToDeviceType(device);
}

static void registerNatives() {
Expand Down Expand Up @@ -108,7 +97,14 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
inputs.reserve(n);
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
if (at::kVulkan == deviceType_) {
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_);
inputs.push_back(std::move(atIValue));
}
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
Expand Down
17 changes: 14 additions & 3 deletions android/pytorch_android/src/main/java/org/pytorch/Module.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,27 @@ public class Module {
private INativePeer mNativePeer;

/**
* Loads a serialized TorchScript module from the specified path on the disk.
* Loads a serialized TorchScript module from the specified path on the disk to run on specified device.
*
* @param modelPath path to file that contains the serialized TorchScript module.
* @param device {@link org.pytorch.Device} to use for running specified module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
*/
public static Module load(final String modelPath) {
public static Module load(final String modelPath, final Device device) {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
return new Module(new NativePeer(modelPath));
return new Module(new NativePeer(modelPath, device));
}

/**
* Loads a serialized TorchScript module from the specified path on the disk to run on CPU.
*
* @param modelPath path to file that contains the serialized TorchScript module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
*/
public static Module load(final String modelPath) {
return load(modelPath, Device.CPU);
}

Module(INativePeer nativePeer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@ class NativePeer implements INativePeer {
private final HybridData mHybridData;

@DoNotStrip
private static native HybridData initHybrid(String moduleAbsolutePath);
private static native HybridData initHybrid(String moduleAbsolutePath, int deviceJniCode);

@DoNotStrip
private static native HybridData initHybridAndroidAsset(
String assetName, /* android.content.res.AssetManager */ Object androidAssetManager);
String assetName, /* android.content.res.AssetManager */ Object androidAssetManager, int deviceJniCode);

NativePeer(String moduleAbsolutePath) {
mHybridData = initHybrid(moduleAbsolutePath);
NativePeer(String moduleAbsolutePath, Device device) {
mHybridData = initHybrid(moduleAbsolutePath, device.jniCode);
}

NativePeer(String assetName, /* android.content.res.AssetManager */ Object androidAssetManager) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager);
NativePeer(String assetName, /* android.content.res.AssetManager */ Object androidAssetManager, Device device) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
}

public void resetNative() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@ public final class PyTorchAndroid {
*
* <p>This method is meant to use in tests and demos.
*/
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName, final Device device) {
return new Module(new NativePeer(assetName, assetManager, device));
}

public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName) {
return new Module(new NativePeer(assetName, assetManager));
return new Module(new NativePeer(assetName, assetManager, Device.CPU));
}

/**
Expand Down