Skip to content

Commit

Permalink
Update on "fix nn.MHA + quantized scriptability"
Browse files Browse the repository at this point in the history
Fixes a post-1.8 regression in nn.MultiheadAttention + quantization scriptability introduced in #52537. Passes the new test introduced in that PR, and fixes the repro found by @ngimel [here](https://gist.github.com/bhosmer/ef517d0774f2f10336b8140116fd6b62). 

Per comments in #52537 there's definitely a carnal dependency between quantization and the `_LinearWithBias` class by name that I'm reinstating here, but there may be cleaner ways to solve this - I don't really know what I'm doing 😁 . 

@jbschlosser @z-a-f LMK if you have ideas, happy to change this as desired. It'd be nice to get a fix into 1.9.

_[Update: now using a better name instead of `_LinearWithBias`, but this remains a short-term fix to re-suppress a quantization API usage error that should properly be raised upstream. See issue #58969]_

Differential Revision: [D28593830](https://our.internmc.facebook.com/intern/diff/D28593830)

[ghstack-poisoned]
  • Loading branch information
bhosmer committed May 26, 2021
2 parents c928f6a + be4ba29 commit 1940055
Show file tree
Hide file tree
Showing 88 changed files with 952 additions and 448 deletions.
4 changes: 4 additions & 0 deletions .circleci/scripts/binary_linux_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,9 @@ else
build_script='manywheel/build.sh'
fi

if [[ "$CIRCLE_BRANCH" == "master" ]] || [[ "$CIRCLE_BRANCH" == release/* ]]; then
export BUILD_DEBUG_INFO=1
fi

# Build the package
SKIP_ALL_TESTS=1 "/builder/$build_script"
8 changes: 8 additions & 0 deletions .github/templates/windows_ci_workflow.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ jobs:
steps:
- name: Checkout PyTorch
uses: actions/checkout@v2
- name: Install 7zip if not already installed
shell: powershell
run: |
choco install 7zip.install -y
- name: Install Visual Studio 2019 toolchain
shell: powershell
run: |
Expand Down Expand Up @@ -73,6 +77,10 @@ jobs:
steps:
- name: Checkout PyTorch
uses: actions/checkout@v2
- name: Install 7zip if not already installed
shell: powershell
run: |
choco install 7zip.install -y
- name: Install Visual Studio 2019 toolchain
shell: powershell
run: |
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/pytorch-win-vs2019-cpu-py3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ jobs:
steps:
- name: Checkout PyTorch
uses: actions/checkout@v2
- name: Install 7zip if not already installed
shell: powershell
run: |
choco install 7zip.install -y
- name: Install Visual Studio 2019 toolchain
shell: powershell
run: |
Expand Down Expand Up @@ -72,6 +76,10 @@ jobs:
steps:
- name: Checkout PyTorch
uses: actions/checkout@v2
- name: Install 7zip if not already installed
shell: powershell
run: |
choco install 7zip.install -y
- name: Install Visual Studio 2019 toolchain
shell: powershell
run: |
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ option(USE_SYSTEM_CPUINFO "Use system-provided cpuinfo." OFF)
option(USE_SYSTEM_SLEEF "Use system-provided sleef." OFF)
option(USE_SYSTEM_GLOO "Use system-provided gloo." OFF)
option(USE_SYSTEM_FP16 "Use system-provided fp16." OFF)
option(USE_SYSTEM_PYBIND11 "Use system-provided PyBind11." OFF)
option(USE_SYSTEM_PTHREADPOOL "Use system-provided pthreadpool." OFF)
option(USE_SYSTEM_PSIMD "Use system-provided psimd." OFF)
option(USE_SYSTEM_FXDIV "Use system-provided fxdiv." OFF)
Expand All @@ -371,6 +372,7 @@ if(USE_SYSTEM_LIBS)
set(USE_SYSTEM_BENCHMARK ON)
set(USE_SYSTEM_ONNX ON)
set(USE_SYSTEM_XNNPACK ON)
set(USE_SYSTEM_PYBIND11 ON)
endif()

# Used when building Caffe2 through setup.py
Expand Down
5 changes: 5 additions & 0 deletions aten/src/ATen/SparseTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ struct TORCH_API SparseTensorImpl : public TensorImpl {
// because many algorithms proceed by merging two sorted lists (of indices).
bool coalesced_ = false;

// Recompute the cached element count by delegating to
// TensorImpl::safe_refresh_numel(), i.e. compute_numel with an integer
// multiplication overflow check, see gh-57542.
// NOTE(review): this shares its name with TensorImpl::refresh_numel(), so
// unqualified calls made from SparseTensorImpl presumably resolve to this
// checked version instead of the base one — confirm the base is non-virtual.
void refresh_numel() {
TensorImpl::safe_refresh_numel();
}

public:
// Public for now...
explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
Expand Down
46 changes: 46 additions & 0 deletions aten/src/ATen/core/custom_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,51 @@ std::vector<c10::FunctionSchema> customClassSchemasForBCCheck() {
});
}

namespace detail {
// Builds the shared, non-templated state for a custom class registration:
// the fully qualified class name and the ClassType object, which is then
// recorded in the custom-class type map and registered.
//
// namespaceName / className  - combined into
//                              "__torch__.torch.classes.<namespace>.<class>".
// doc_string                 - moved into the created ClassType.
// intrusivePtrClassTypeid,
// taggedCapsuleClassTypeid   - both are used as keys (via std::type_index)
//                              that map to the same ClassType.
class_base::class_base(
const std::string& namespaceName,
const std::string& className,
std::string doc_string,
const std::type_info& intrusivePtrClassTypeid,
const std::type_info& taggedCapsuleClassTypeid)
: qualClassName("__torch__.torch.classes." + namespaceName + '.' + className),
classTypePtr(at::ClassType::create(
c10::QualifiedName(qualClassName),
// No owning compilation unit is supplied: an empty weak_ptr is passed.
std::weak_ptr<jit::CompilationUnit>(),
/*is_module=*/false,
std::move(doc_string)))
{
// Note: these checks run in the constructor body, i.e. after the member
// initializers above have already used both names to build qualClassName
// and the ClassType. NOTE(review): presumably checkValidIdent raises on an
// invalid identifier — confirm against its definition.
detail::checkValidIdent(namespaceName, "Namespace name");
detail::checkValidIdent(className, "Class name");
// Every custom class type gets a "capsule" attribute of CapsuleType.
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
// Both C++ type ids are mapped to the same ClassType.
c10::getCustomClassTypeMap().insert(
{std::type_index(intrusivePtrClassTypeid), classTypePtr});
c10::getCustomClassTypeMap().insert(
{std::type_index(taggedCapsuleClassTypeid), classTypePtr});

registerCustomClass(classTypePtr);
}

// Returns a copy of `schema` whose argument list is rebuilt from
// `default_args`.
//
// The first argument of the original schema (self) is copied over unchanged.
// Each entry of `default_args` is then paired, in order, with the next
// original argument: the name and default value come from the entry, while
// the type and N come from the original argument.
//
// No bounds checking is done here: the schema must have at least one
// argument and at least `default_args.size()` arguments after self.
c10::FunctionSchema class_base::withNewArguments(
    const c10::FunctionSchema& schema,
    std::initializer_list<arg> default_args) {
  const auto& original = schema.arguments();

  std::vector<c10::Argument> rebuilt;
  rebuilt.reserve(original.size());

  // Position of the next original argument to consume; slot 0 is self.
  size_t cursor = 0;
  rebuilt.push_back(original[cursor]);
  ++cursor;

  for (auto it = default_args.begin(); it != default_args.end(); ++it) {
    const auto& source_arg = original[cursor];
    ++cursor;
    rebuilt.emplace_back(
        it->name_, source_arg.type(), source_arg.N(), it->value_);
  }

  return schema.cloneWithArguments(std::move(rebuilt));
}

} // namespace detail
} // namespace torch
2 changes: 1 addition & 1 deletion aten/src/ATen/native/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ CREATE_UNARY_FLOAT_META_FUNC(erfinv)
CREATE_UNARY_FLOAT_META_FUNC(exp)
CREATE_UNARY_FLOAT_META_FUNC(exp2)
CREATE_UNARY_FLOAT_META_FUNC(expm1)
CREATE_UNARY_FLOAT_META_FUNC(i0)
CREATE_UNARY_FLOAT_META_FUNC(lgamma)
CREATE_UNARY_FLOAT_META_FUNC(log)
CREATE_UNARY_FLOAT_META_FUNC(log10)
Expand Down Expand Up @@ -78,7 +79,6 @@ TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) {
}
CREATE_UNARY_META_FUNC(bitwise_not)
CREATE_UNARY_META_FUNC(frac)
CREATE_UNARY_META_FUNC(i0)
CREATE_UNARY_META_FUNC(round)
CREATE_UNARY_META_FUNC(sgn)

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void exp2_kernel_cuda(TensorIteratorBase& iter) {
}

void i0_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_i0(a);
});
Expand Down
8 changes: 2 additions & 6 deletions c10/core/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class C10_API Scalar {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)

// also support scalar.to<int64_t>();
// Deleted for unsupported types, but specialized below for supported types
template <typename T>
T to() const;
T to() const = delete;

#undef DEFINE_ACCESSOR
bool isFloatingPoint() const {
Expand Down Expand Up @@ -186,11 +187,6 @@ class C10_API Scalar {
};

// define the scalar.to<int64_t>() specializations
template <typename T>
inline T Scalar::to() const {
throw std::runtime_error("to() cast to unexpected type.");
}

#define DEFINE_TO(T, name) \
template <> \
inline T Scalar::to<T>() const { \
Expand Down
33 changes: 32 additions & 1 deletion c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return n;
}

/**
 * Overflow-checked variant of the element-count computation.
 *
 * Multiplies all entries of sizes() together, starting from 1. Before each
 * multiplication it verifies that the running product times the next size
 * still fits into int64_t; otherwise TORCH_CHECK reports
 * "numel: integer multiplication overflow". A size of 0 is always accepted
 * (and also avoids dividing by zero in the check). Intended for tensors
 * using a sparse layout, where several large dimensions may occur together.
 */
int64_t safe_compute_numel() const {
  const auto& dims = sizes();
  int64_t n = 1;
  for (auto it = dims.begin(); it != dims.end(); ++it) {
    const auto s = *it;
    // n * s overflows exactly when n > max / s (for s != 0).
    TORCH_CHECK(
        s == 0 || n <= std::numeric_limits<int64_t>::max() / s,
        "numel: integer multiplication overflow");
    n *= s;
  }
  return n;
}

/**
* Compute whether or not a tensor is contiguous based on the sizes and
* strides of a tensor.
Expand All @@ -2041,12 +2057,27 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {

protected:
/**
* Recompute the cached numel of a tensor. Call this if you modify sizes.
* Recompute the cached numel of a tensor. Call this if you modify
* sizes.
*
* For tensors with sparse layouts, use safe_refresh_numel() instead
* because it will catch integer overflow that may occur for tensors
* with sparse layouts and large dimensions.
*/
void refresh_numel() {
numel_ = compute_numel();
}

/**
* Recompute the cached numel of a tensor using safe_compute_numel() and
* store the result in numel_. Call this if you modify sizes. Use only for
* tensors with sparse layouts, because only sparse tensors are likely to
* have sizes that may lead to integer overflow when computing numel.
*/
void safe_refresh_numel() {
numel_ = safe_compute_numel();
}

/**
* Recompute the cached contiguity of a tensor. Call this if you modify sizes
* or strides.
Expand Down
2 changes: 1 addition & 1 deletion caffe2/python/operator_test/activation_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def leaky_relu_ref(X):
@given(X=hu.tensor(),
fast_gelu=st.booleans(),
**hu.gcs)
@settings(deadline=1000)
@settings(deadline=10000)
def test_gelu(self, X, fast_gelu, gc, dc):
op = core.CreateOperator(
"Gelu",
Expand Down
2 changes: 1 addition & 1 deletion caffe2/python/operator_test/adadelta_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def ref_adadelta(param_in,
decay=hu.floats(min_value=0.01, max_value=0.99,
allow_nan=False, allow_infinity=False),
**hu.gcs)
@settings(deadline=1000)
@settings(deadline=10000)
def test_adadelta(self, inputs, lr, epsilon, decay, gc, dc):
param, moment, moment_delta, grad = inputs
moment = np.abs(moment)
Expand Down
8 changes: 4 additions & 4 deletions caffe2/python/operator_test/adagrad_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TestAdagrad(serial.SerializedTestCase):
weight_decay=st.sampled_from([0.0, 0.1]),
**hu.gcs
)
@settings(deadline=1000)
@settings(deadline=10000)
def test_adagrad(self, inputs, lr, epsilon, weight_decay, gc, dc):
param, momentum, grad = inputs
momentum = np.abs(momentum)
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_adagrad_output_effective_lr(
),
**hu.gcs_cpu_only
)
@settings(deadline=1000)
@settings(deadline=10000)
def test_adagrad_output_effective_lr_and_update(self, inputs, lr, epsilon, gc, dc):
param, momentum, grad = inputs
momentum = np.abs(momentum)
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_sparse_adagrad(self, inputs, lr, epsilon, weight_decay, gc, dc):
),
**hu.gcs
)
@settings(deadline=1000)
@settings(deadline=10000)
def test_sparse_adagrad_empty(self, inputs, lr, epsilon, gc, dc):
param, momentum = inputs
grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32)
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_sparse_adagrad_empty(self, inputs, lr, epsilon, gc, dc):

# Suppress filter_too_much health check.
# Likely caused by `assume` call falling through too often.
@settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=1000)
@settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=10000)
@given(
inputs=hu.tensors(n=3),
lr=st.floats(
Expand Down
2 changes: 1 addition & 1 deletion caffe2/python/operator_test/assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TestAssert(hu.HypothesisTestCase):
dtype=st.sampled_from(['bool_', 'int32', 'int64']),
shape=st.lists(elements=st.integers(1, 10), min_size=1, max_size=4),
**hu.gcs)
@settings(deadline=1000)
@settings(deadline=10000)
def test_assert(self, dtype, shape, gc, dc):
test_tensor = np.random.rand(*shape).astype(np.dtype(dtype))

Expand Down
4 changes: 2 additions & 2 deletions caffe2/python/operator_test/batch_sparse_to_dense_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestBatchSparseToDense(serial.SerializedTestCase):
default_value=st.floats(min_value=2.0, max_value=3.0),
**hu.gcs
)
@settings(deadline=1000)
@settings(deadline=None)
def test_batch_sparse_to_dense(
self, batch_size, dense_last_dim, default_value, gc, dc
):
Expand Down Expand Up @@ -75,7 +75,7 @@ def batch_sparse_to_dense_ref(L, I, V, S=None):
dense_last_dim=st.integers(5, 10),
**hu.gcs
)
@settings(deadline=1000)
@settings(deadline=None)
def test_batch_dense_to_sparse(self, batch_size, dense_last_dim, gc, dc):
L = np.random.randint(1, dense_last_dim + 1, size=(batch_size))
# The following logic ensures that indices in each batch will not be duplicated
Expand Down
4 changes: 2 additions & 2 deletions caffe2/python/operator_test/bbox_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class TestBBoxTransformOp(serial.SerializedTestCase):
clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
**hu.gcs_cpu_only
)
@settings(deadline=1000)
@settings(deadline=10000)
def test_bbox_transform(
self,
num_rois,
Expand Down Expand Up @@ -282,7 +282,7 @@ def bbox_transform_ref(rois, deltas, im_info):
clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
**hu.gcs_cpu_only
)
@settings(deadline=1000)
@settings(deadline=10000)
def test_bbox_transform_batch(
self,
roi_counts,
Expand Down
4 changes: 2 additions & 2 deletions caffe2/python/operator_test/boolean_mask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TestBooleanMaskOp(serial.SerializedTestCase):
max_len=100,
elements=hu.floats(min_value=0.5, max_value=1.0)),
**hu.gcs_cpu_only)
@settings(deadline=1000)
@settings(deadline=10000)
def test_boolean_mask_gradient(self, x, gc, dc):
op = core.CreateOperator("BooleanMask",
["data", "mask"],
Expand All @@ -30,7 +30,7 @@ def test_boolean_mask_gradient(self, x, gc, dc):
max_len=5,
elements=hu.floats(min_value=0.5, max_value=1.0)),
**hu.gcs)
@settings(deadline=1000)
@settings(deadline=10000)
def test_boolean_mask(self, x, gc, dc):
op = core.CreateOperator("BooleanMask",
["data", "mask"],
Expand Down
6 changes: 3 additions & 3 deletions caffe2/python/operator_test/box_with_nms_limit_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def ref(*args, **kwargs):
self.assertReferenceChecks(gc, op, [scores, boxes], ref)

@given(**HU_CONFIG)
@settings(deadline=1000)
@settings(deadline=10000)
def test_score_thresh(self, gc):
in_centers = [(0, 0), (20, 20), (50, 50)]
in_scores = [0.7, 0.85, 0.6]
Expand All @@ -102,7 +102,7 @@ def ref(*args, **kwargs):
self.assertReferenceChecks(gc, op, [scores, boxes], ref)

@given(det_per_im=st.integers(1, 3), **HU_CONFIG)
@settings(deadline=1000)
@settings(deadline=10000)
def test_detections_per_im(self, det_per_im, gc):
in_centers = [(0, 0), (20, 20), (50, 50)]
in_scores = [0.7, 0.85, 0.6]
Expand Down Expand Up @@ -131,7 +131,7 @@ def ref(*args, **kwargs):
output_classes_include_bg_cls=st.booleans(),
**HU_CONFIG
)
@settings(deadline=1000)
@settings(deadline=10000)
def test_multiclass(
self,
num_classes,
Expand Down
2 changes: 1 addition & 1 deletion caffe2/python/operator_test/clip_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestClip(serial.SerializedTestCase):
max_=st.floats(min_value=0, max_value=2),
inplace=st.booleans(),
**hu.gcs)
@settings(deadline=1000)
@settings(deadline=10000)
def test_clip(self, X, min_, max_, inplace, gc, dc):
# go away from the origin point to avoid kink problems
if np.isscalar(X):
Expand Down
2 changes: 1 addition & 1 deletion caffe2/python/operator_test/clip_tensor_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestClipTensorByScalingOp(serial.SerializedTestCase):
use_additional_threshold=st.booleans(),
inplace=st.booleans(),
**hu.gcs_cpu_only)
@settings(deadline=1000)
@settings(deadline=10000)
def test_clip_tensor_by_scaling(self, n, d, threshold, additional_threshold,
use_additional_threshold, inplace, gc, dc):

Expand Down

0 comments on commit 1940055

Please sign in to comment.