Update on "[quant][graphmode][fx] Support quantization for standalone module"

Summary:
Sometimes a user needs to quantize a submodule as one unit, for example because the submodule
will be lowered to a different backend such as an accelerator.

The submodule is quantized with the same FX-based graph mode quantization functions
and is connected to the rest of the model automatically.

APIs:
```python
import torch
from torch.fx import Tracer  # torch.fx was experimental at the time of this PR
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# The submodule that should be quantized as a single unit.
class StandaloneModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

# Example tracer that keeps StandaloneModule as a leaf during symbolic
# tracing (shown for illustration; not passed to prepare_fx below).
class CustomTracer(Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        return (m.__module__.startswith('torch.nn') and
                not isinstance(m, torch.nn.Sequential)) or \
            isinstance(m, StandaloneModule)

class ModelThatUsesStandaloneModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.standalone = StandaloneModule()

    def forward(self, x):
        return self.standalone(x)

# User-supplied calibration loop; `data` is any iterable of representative inputs.
def calibrate(model, data):
    with torch.no_grad():
        for x in data:
            model(x)

m = ModelThatUsesStandaloneModule()
qconfig = get_default_qconfig("fbgemm")
# "standalone_module_name" lists the submodules to quantize as one unit.
qconfig_dict = {"": qconfig, "standalone_module_name": ["standalone"]}
m = prepare_fx(m, qconfig_dict)
calibrate(m, data)
m = convert_fx(m)

# lower_to_accelerator is a user-provided (hypothetical) lowering function.
m.standalone = lower_to_accelerator(m.standalone)
```
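Because the standalone module is quantized with the same flow and kept as a single unit in the parent graph, its boundary with the rest of the quantized model lines up automatically; replacing `m.standalone` with a lowered implementation afterwards is then a plain attribute assignment, as the last line of the example shows.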

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23580642](https://our.internmc.facebook.com/intern/diff/D23580642)

[ghstack-poisoned]
jerryzh168 committed Sep 29, 2020
2 parents b6b9998 + 12fef82 commit 6587a4b
Showing 114 changed files with 2,514 additions and 1,629 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -640,6 +640,7 @@ jobs:
export CIRCLE_SHA1="$CIRCLE_SHA1"
export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}"
export CIRCLE_BRANCH="$CIRCLE_BRANCH"
export CIRCLE_JOB="$CIRCLE_JOB"
cd workspace
python test/print_test_stats.py test
EOL
1 change: 1 addition & 0 deletions .circleci/verbatim-sources/job-specs/pytorch-job-specs.yml
@@ -206,6 +206,7 @@ jobs:
export CIRCLE_SHA1="$CIRCLE_SHA1"
export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}"
export CIRCLE_BRANCH="$CIRCLE_BRANCH"
export CIRCLE_JOB="$CIRCLE_JOB"
cd workspace
python test/print_test_stats.py test
EOL
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -292,6 +292,12 @@ if(LINUX)
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed")
endif()

if(WIN32 AND USE_DISTRIBUTED)
if(NOT DEFINED ENV{libuv_ROOT})
set(ENV{libuv_ROOT} $ENV{CONDA_PREFIX}\\Library)
endif()
endif()

if(MSVC)
foreach(flag_var
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
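Taken together with the README change below, this removes the manual `set libuv_ROOT=...` step on Windows: when `libuv_ROOT` is not set, the build now defaults it to `$ENV{CONDA_PREFIX}\Library`, which is where `conda install libuv` places the library.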
2 changes: 1 addition & 1 deletion Dockerfile
@@ -8,7 +8,7 @@
# For reference:
# https://docs.docker.com/develop/develop-images/build_enhancements/
ARG BASE_IMAGE=ubuntu:18.04
ARG PYTHON_VERSION=3.7
ARG PYTHON_VERSION=3.8

FROM ${BASE_IMAGE} as dev-base
RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \
3 changes: 1 addition & 2 deletions README.md
@@ -186,10 +186,9 @@ conda install pkg-config libuv

On Windows
```bash
# Add these packages and set libuv_ROOT environment variable if torch.distributed is needed.
# Add these packages if torch.distributed is needed.
# Distributed package support on Windows is a prototype feature and is subject to changes.
conda install -y -q -c rdonnelly libuv
set libuv_ROOT={conda active env location}\Library
```

#### Get the PyTorch Source
11 changes: 11 additions & 0 deletions android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp
@@ -18,6 +18,17 @@

namespace pytorch_jni {

c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
if (deviceJniCode == kDeviceCPU) {
return at::kCPU;
} else if (deviceJniCode == kDeviceVulkan) {
return at::kVulkan;
}

facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
}

bool Trace::is_initialized_ = false;

#if defined(TRACE_ENABLED) && defined(__ANDROID__)
7 changes: 7 additions & 0 deletions android/pytorch_android/src/main/cpp/pytorch_jni_common.h
@@ -1,3 +1,5 @@
#pragma once

#include <fbjni/fbjni.h>
#include <torch/csrc/api/include/torch/types.h>

@@ -18,6 +20,11 @@

namespace pytorch_jni {

constexpr static int kDeviceCPU = 1;
constexpr static int kDeviceVulkan = 2;

c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);

class Trace {
public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
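These JNI codes line up with the `org.pytorch.Device` enum on the Java side (its `jniCode` field appears in the `NativePeer` diff below). A minimal sketch of that mapping, inferred from the diffs in this commit rather than copied from the source:

```java
package org.pytorch;

// Sketch: Java-side counterpart of the kDeviceCPU / kDeviceVulkan constants above.
public enum Device {
  CPU(1),     // kDeviceCPU
  VULKAN(2);  // kDeviceVulkan

  final int jniCode;

  Device(int jniCode) {
    this.jniCode = jniCode;
  }
}
```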
36 changes: 28 additions & 8 deletions android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
@@ -67,22 +67,25 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
torch::jit::Module module_;
c10::DeviceType deviceType_;

public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";

static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath) {
return makeCxxInstance(modelPath);
facebook::jni::alias_ref<jstring> modelPath,
jint device) {
return makeCxxInstance(modelPath, device);
}

#ifdef __ANDROID__
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager) {
return makeCxxInstance(assetName, assetManager);
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
return makeCxxInstance(assetName, assetManager, device);
}
#endif

@@ -127,17 +130,19 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
((void)once);
}

PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
PytorchJni(facebook::jni::alias_ref<jstring> modelPath, jint device) {
preModuleLoadSetup();
JITCallGuard guard;
module_ = torch::jit::load(std::move(modelPath->toStdString()));
module_.eval();
deviceType_ = deviceJniCodeToDeviceType(device);
}

#ifdef __ANDROID__
PytorchJni(
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager) {
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
preModuleLoadSetup();
JNIEnv* env = facebook::jni::Environment::current();
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
@@ -166,6 +171,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
module_.eval();
deviceType_ = deviceJniCodeToDeviceType(device);
}
#endif

@@ -191,7 +197,14 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
inputs.reserve(n);
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
if (at::kVulkan == deviceType_) {
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_);
inputs.push_back(std::move(atIValue));
}
}
auto output = [&]() {
JITCallGuard guard;
@@ -212,7 +225,14 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
inputs.reserve(n);
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
if (at::kVulkan == deviceType_) {
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_);
inputs.push_back(std::move(atIValue));
}
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
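The pattern repeated in the `forward` and `runMethod` hunks above is the core of the Vulkan path: when the module was loaded with the Vulkan device code, tensor inputs are moved to Vulkan via `.vulkan()` before dispatch, non-tensor `IValue`s pass through unchanged, and a `TORCH_CHECK` rejects any device code other than CPU.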
22 changes: 9 additions & 13 deletions android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp
@@ -30,9 +30,6 @@ struct LiteJITCallGuard {
} // namespace

class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
constexpr static int kDeviceCPU = 1;
constexpr static int kDeviceVulkan = 2;

private:
friend HybridBase;
torch::jit::mobile::Module module_;
@@ -51,15 +48,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
PytorchJni(facebook::jni::alias_ref<jstring> modelPath, jint device) {
LiteJITCallGuard guard;
module_ = torch::jit::_load_for_mobile(std::move(modelPath->toStdString()));
if (device == kDeviceCPU) {
deviceType_ = at::kCPU;
} else if (device == kDeviceVulkan) {
deviceType_ = at::kVulkan;
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown device specified");
}
deviceType_ = deviceJniCodeToDeviceType(device);
}

static void registerNatives() {
@@ -108,7 +97,14 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
inputs.reserve(n);
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
if (at::kVulkan == deviceType_) {
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_);
inputs.push_back(std::move(atIValue));
}
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
17 changes: 14 additions & 3 deletions android/pytorch_android/src/main/java/org/pytorch/Module.java
@@ -11,16 +11,27 @@ public class Module {
private INativePeer mNativePeer;

/**
* Loads a serialized TorchScript module from the specified path on the disk.
* Loads a serialized TorchScript module from the specified path on the disk to run on specified device.
*
* @param modelPath path to file that contains the serialized TorchScript module.
* @param device {@link org.pytorch.Device} to use for running specified module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
*/
public static Module load(final String modelPath) {
public static Module load(final String modelPath, final Device device) {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
return new Module(new NativePeer(modelPath));
return new Module(new NativePeer(modelPath, device));
}

/**
* Loads a serialized TorchScript module from the specified path on the disk to run on CPU.
*
* @param modelPath path to file that contains the serialized TorchScript module.
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
*/
public static Module load(final String modelPath) {
return load(modelPath, Device.CPU);
}

Module(INativePeer nativePeer) {
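A hedged usage sketch of the new overload (the model path and input shape are illustrative; `Tensor.fromBlob`, `IValue.from`, and `forward` are the existing Java API):

```java
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class VulkanLoadExample {
  public static void main(String[] args) {
    // Load a TorchScript model on the Vulkan backend; the one-argument
    // Module.load(path) overload still defaults to CPU.
    Module module = Module.load("/data/local/tmp/mobilenet2-vulkan.pt", Device.VULKAN);
    float[] data = new float[1 * 3 * 224 * 224];
    Tensor input = Tensor.fromBlob(data, new long[] {1, 3, 224, 224});
    Tensor output = module.forward(IValue.from(input)).toTensor();
    System.out.println(java.util.Arrays.toString(output.shape()));
  }
}
```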
android/pytorch_android/src/main/java/org/pytorch/NativePeer.java
@@ -13,18 +13,18 @@ class NativePeer implements INativePeer {
private final HybridData mHybridData;

@DoNotStrip
private static native HybridData initHybrid(String moduleAbsolutePath);
private static native HybridData initHybrid(String moduleAbsolutePath, int deviceJniCode);

@DoNotStrip
private static native HybridData initHybridAndroidAsset(
String assetName, /* android.content.res.AssetManager */ Object androidAssetManager);
String assetName, /* android.content.res.AssetManager */ Object androidAssetManager, int deviceJniCode);

NativePeer(String moduleAbsolutePath) {
mHybridData = initHybrid(moduleAbsolutePath);
NativePeer(String moduleAbsolutePath, Device device) {
mHybridData = initHybrid(moduleAbsolutePath, device.jniCode);
}

NativePeer(String assetName, /* android.content.res.AssetManager */ Object androidAssetManager) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager);
NativePeer(String assetName, /* android.content.res.AssetManager */ Object androidAssetManager, Device device) {
mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode);
}

public void resetNative() {
android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java
@@ -21,9 +21,14 @@ public final class PyTorchAndroid {
*
* <p>This method is meant to use in tests and demos.
*/
public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName, final Device device) {
return new Module(new NativePeer(assetName, assetManager, device));
}

public static Module loadModuleFromAsset(
final AssetManager assetManager, final String assetName) {
return new Module(new NativePeer(assetName, assetManager));
return new Module(new NativePeer(assetName, assetManager, Device.CPU));
}

/**
11 changes: 10 additions & 1 deletion android/test_app/app/build.gradle
@@ -40,6 +40,7 @@ android {
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
@@ -66,9 +67,17 @@ android {
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
}
mbvulkan {
dimension "model"
applicationIdSuffix ".mbvulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
}
resnet18 {
dimension "model"
applicationIdSuffix ".resneti18"
applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java
@@ -17,6 +17,7 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
@@ -126,7 +127,9 @@ protected Result doModuleForward() {
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
PyTorchAndroid.setNumThreads(1);
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
mModule = BuildConfig.USE_VULKAN_DEVICE
? PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
}

final long startTime = SystemClock.elapsedRealtime();
9 changes: 3 additions & 6 deletions aten/src/ATen/ThreadLocalState.cpp
@@ -10,9 +10,8 @@ namespace at {

ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
: dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
debug_info_(c10::ThreadLocalDebugInfo::current()),
observers_enabled_(at::isRecordFunctionEnabled()) {
callbacks_ = _getTLSCallbacks();
debug_info_(c10::ThreadLocalDebugInfo::current()) {
rf_tls_ = at::get_record_function_tls_();

#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
keep_grad_mode_ = keep_grad_mode;
@@ -31,9 +30,7 @@ void ThreadLocalState::setThreadLocalState(
}
#endif

_setTLSCallbacks(state.callbacks_);

at::enableRecordFunction(state.observers_enabled_);
at::set_record_function_tls_(state.rf_tls_);

c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);

6 changes: 2 additions & 4 deletions aten/src/ATen/ThreadLocalState.h
@@ -30,10 +30,8 @@ class TORCH_API ThreadLocalState {
// with DebugInfoGuard
std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;

// RecordFunction TLS callbacks
RecordFunctionCallbacks callbacks_;

bool observers_enabled_ = false;
// RecordFunction TLS
RecordFunctionTLS rf_tls_;

#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
bool keep_grad_mode_ = true;
3 changes: 2 additions & 1 deletion aten/src/ATen/native/BinaryOps.h
@@ -10,7 +10,8 @@ namespace at { namespace native {
inline void alpha_check(const ScalarType dtype, Scalar alpha) {
TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
"Boolean alpha only supported for Boolean results.");
TORCH_CHECK(isFloatingType(dtype) || alpha.isIntegral(true),
TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
|| alpha.isIntegral(true),
"For integral input tensors, argument alpha must not be a floating point number.");
}

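The added `isComplexType(dtype)` clause means binary ops producing complex results now accept a complex (or floating-point) `alpha`; previously such an `alpha` failed the integral check even though the complex result dtype could represent it.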
