diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index 98bbefc890e0..a62b8dbbd4e9 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -5,6 +5,7 @@
 
 #ifdef USE_FBGEMM
 #include "fbgemm/Fbgemm.h"
+#include "fbgemm/FbgemmFP16.h"
 #include "fbgemm/QuantUtils.h"
 #endif // USE_FBGEMM
 
@@ -22,6 +23,7 @@ namespace caffe2 {
 #ifdef USE_FBGEMM
 // Required for cpp_custom_type_hack to work
 CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
+CAFFE_KNOWN_TYPE(fbgemm::PackedGemmMatrixFP16);
 #endif // USE_FBGEMM
 }
 
@@ -41,24 +43,24 @@ Tensor fbgemm_linear_int8_weight(
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
-  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
 
   auto input_contig = input.contiguous();
   auto* input_ptr = input_contig.data<float>();
 
-  AT_ASSERT(input.dim() >= 2);
+  TORCH_CHECK(input.dim() >= 2);
   int64_t M = 1;
   for (size_t i = 0; i < input.dim() - 1; ++i) {
     M *= input.size(i);
   }
   int64_t K = input.size(input.dim() - 1);
-  AT_ASSERT(weight.dim() == 2);
-  AT_ASSERT(K == weight.size(1));
+  TORCH_CHECK(weight.dim() == 2);
+  TORCH_CHECK(K == weight.size(1));
   auto N = weight.size(0);
-  AT_ASSERT(bias.dim() == 1);
-  AT_ASSERT(bias.size(0) == N);
-  AT_ASSERT(weight_scale.isFloatingPoint());
-  AT_ASSERT(weight_zero_point.isIntegral());
+  TORCH_CHECK(bias.dim() == 1);
+  TORCH_CHECK(bias.size(0) == N);
+  TORCH_CHECK(weight_scale.isFloatingPoint());
+  TORCH_CHECK(weight_zero_point.isIntegral());
 
   // Calculate statistics for quantization of the input Tensor
   float x_min, x_max;
@@ -182,7 +184,7 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
-  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
   auto weight_contig = weight.contiguous();
 
   // Calculate weight statistics
@@ -238,7 +240,7 @@ Tensor fbgemm_pack_quantized_matrix(
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
-  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
   auto weight_contig = weight.contiguous();
   auto contiguous_ptr = weight_contig.data<int8_t>();
   auto ptr = guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
@@ -252,6 +254,121 @@ Tensor fbgemm_pack_quantized_matrix(
   return cpp_custom_type_hack::create(std::move(ptr), weight.options());
 }
 
+float raw_uint16_to_fp16(unsigned short value)
+{
+  // Converts a raw 16-bit half-precision floating point number
+  // to a single-precision floating point number.
+  unsigned short sign_bits = value >> 15;
+  unsigned short exponent_bits = value >> 10 & 0x1f;
+  unsigned short significand_bits = value & 0x3ff;
+
+  float sign = sign_bits ? -1 : 1;
+  float significand = 1 + significand_bits * 0x1p-10;
+  float exponent = exponent_bits - 0xf;
+
+  return sign * std::ldexp(significand, exponent);
+}
+
+template <typename T>
+bool check_and_saturate(T* element, T MAX) {
+  if (*element > MAX) {
+    *element = MAX;
+    return true;
+  }
+  if (*element < -MAX) {
+    *element = -MAX;
+    return true;
+  }
+  return false;
+}
+
+// FP16 can only represent values whose magnitude is at most 65504 (and, for
+// nonzero values, at least 5.96e-8). Weights whose magnitude exceeds the FP16
+// maximum are saturated to +/-65504 before packing, and a warning is emitted.
+void handle_weights_saturation(float* weight, int64_t length) {
+  float FP16_MAX = raw_uint16_to_fp16(0x7BFF);
+  bool found_out_of_range = false;
+
+  for (int64_t i = 0; i < length; ++i) {
+    if (check_and_saturate(&weight[i], FP16_MAX)) {
+      found_out_of_range = true;
+    }
+  }
+
+  if (found_out_of_range) {
+    TORCH_WARN("Found weight values outside the FP16 range; they were saturated.");
+  }
+}
+
+Tensor fbgemm_pack_gemm_matrix_fp16(
+    const Tensor& weight) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
+
+  int64_t K = weight.size(1);
+  int64_t N = weight.size(0);
+  Tensor weight_contig = weight.contiguous();
+  auto weight_contig_ptr = weight_contig.data<float>();
+
+  handle_weights_saturation(weight_contig_ptr, K * N);
+
+  // TODO(mingzhe09088):
+  // Consider using a functor here in PackedGemmMatrixFP16
+  // Comment from XQ: Not entirely sure this make_unique is safe. The object it
+  // creates is allocated with a regular "new" and freed through
+  // TypeMetaData::deleteFn in this function. This is perfectly fine if the
+  // tensors are created and freed within this translation unit, but it may be
+  // very problematic if such a tensor flows across DLL boundaries.
+  auto ptr = guts::make_unique<fbgemm::PackedGemmMatrixFP16>(
+      fbgemm::matrix_op_t::Transpose,
+      K,
+      N,
+      1,
+      weight_contig_ptr);
+  return cpp_custom_type_hack::create(std::move(ptr), weight.options());
+}
+
+Tensor fbgemm_linear_fp16_weight(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const Tensor& bias) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
+
+  auto input_contig = input.contiguous();
+  auto* input_ptr = input_contig.data<float>();
+
+  // Pull out the PackedGemmMatrixFP16 instance from the owning tensor
+  const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 =
+      cpp_custom_type_hack::cast<fbgemm::PackedGemmMatrixFP16>(packed_weight);
+
+  TORCH_CHECK(input.dim() >= 2);
+  TORCH_CHECK(input.size(1) == packed_weight_fp16.numRows());
+  int64_t M = input.size(0);
+  TORCH_CHECK(bias.dim() == 1);
+  int64_t N = packed_weight_fp16.numCols();
+
+  auto output = at::empty({M, N}, bias.options().dtype(at::kFloat));
+
+  // Call the fp16 gemm interface
+  fbgemm::cblas_gemm_compute(
+      fbgemm::matrix_op_t::NoTranspose,
+      M,
+      input_ptr,
+      packed_weight_fp16,
+      0.f,
+      output.data<float>());
+
+  // Add bias term
+  output.add_(bias);
+
+  return output;
+}
+
 #else // USE_FBGEMM
 
 Tensor fbgemm_linear_int8_weight(
@@ -265,7 +382,7 @@ Tensor fbgemm_linear_int8_weight(
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
-  AT_ASSERTM(
+  TORCH_CHECK(
       false, "This PyTorch installation was not built with FBGEMM operators");
 }
 
@@ -274,7 +391,7 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
-  AT_ASSERTM(
+  TORCH_CHECK(
       false, "This PyTorch installation was not built with FBGEMM operators");
 }
 
@@ -285,7 +402,27 @@ Tensor fbgemm_pack_quantized_matrix(
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
-  AT_ASSERTM(
+  TORCH_CHECK(
       false, "This PyTorch installation was not built with FBGEMM operators");
 }
 
+Tensor fbgemm_pack_gemm_matrix_fp16(
+    const Tensor& weight) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
+Tensor fbgemm_linear_fp16_weight(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const Tensor& bias) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index be679ec7848f..5b57baafb064 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1101,6 +1101,10 @@
 
 - func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)
 
+- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor
+
+- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+
 - func: fbgemm_pack_quantized_matrix(Tensor input, int K, int N) -> Tensor
 
 - func: fbgemm_is_cpu_supported() -> bool
diff --git a/test/test_nn.py b/test/test_nn.py
index 45189fceddb0..d10b917ecc2c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -36,6 +36,7 @@
     module_tests, criterion_tests, loss_reference_fns, get_reduction, \
     get_weight, smoothl1loss_reference, kldivloss_reference, \
     ctcloss_reference, new_module_tests
+from common_utils import TEST_WITH_UBSAN
 from torch.nn import MultiheadAttention
 
 
@@ -2516,6 +2517,26 @@ def test_softmax_backward(self):
     def test_softmax_backward_cuda(self):
         self._test_softmax_backward(torch.device('cuda'))
 
+    @unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
+                     'Linear_FP16_weight requires FBGEMM. FBGEMM does not play'
+                     ' well with UBSAN at the moment, so we skip the test if'
+                     ' we are in a UBSAN environment.')
+    def test_fb_fc_packed(self):
+        X = np.random.rand(16, 16).astype(np.float32) - 0.5
+        W = np.random.rand(16, 16).astype(np.float32) - 0.5
+        b = np.random.rand(16).astype(np.float32) - 0.5
+
+        def fc_op(X, W, b):
+            return np.dot(X, W.T) + b
+
+        x_tensor = torch.tensor(X)
+        w_tensor = torch.tensor(W)
+        b_tensor = torch.tensor(b)
+        packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
+        actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
+        expected_output = fc_op(X, W, b)
+        torch.testing.assert_allclose(expected_output, actual_output.cpu(), atol=1e-3, rtol=1e-3)
+
     def _test_gumbel_softmax_st_shapes(self, cuda, dtype, shape, dim, count_expected):
         logits = torch.randn(shape, dtype=torch.float)
         logits = logits.to(dtype)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 5fbbf2933fb9..a5031c6be6ff 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -211,6 +211,9 @@
 - name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor
   self: not_implemented("cholesky_inverse")
 
+- name: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+  self: not_implemented("fbgemm_linear_fp16_weight only supported for inference")
+
 # For clamp, gradient is not defined at the boundaries. But empirically it's helpful
 # to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
 - name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
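
For reference, the saturation bound used by handle_weights_saturation() can be checked numerically. The sketch below is illustrative only and not part of the patch (it assumes NumPy is available); it mirrors the bit-level decoding done by raw_uint16_to_fp16() for normal (non-denormal) values and confirms that 0x7BFF, the largest finite half-precision bit pattern, decodes to 65504.

    # Illustrative sketch (not part of the patch): decode an FP16 bit pattern
    # the same way raw_uint16_to_fp16() in QuantizedLinear.cpp does.
    import math

    import numpy as np

    def raw_uint16_to_fp16(value):
        sign_bits = value >> 15                # 1 sign bit
        exponent_bits = (value >> 10) & 0x1F   # 5 exponent bits, bias 15
        significand_bits = value & 0x3FF       # 10 fraction bits

        sign = -1.0 if sign_bits else 1.0
        significand = 1.0 + significand_bits * 2.0 ** -10  # implicit leading 1
        exponent = exponent_bits - 0xF
        return sign * math.ldexp(significand, exponent)

    fp16_max = raw_uint16_to_fp16(0x7BFF)
    assert fp16_max == 65504.0                           # the saturation bound
    assert fp16_max == float(np.finfo(np.float16).max)   # matches NumPy's FP16 max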
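
The intended calling pattern for the two new ops, as exercised by test_fb_fc_packed above, is to pack the FP32 weight into an FP16 packed matrix once and then reuse it for every forward call. A minimal sketch, assuming a PyTorch build with FBGEMM enabled (the guard mirrors the test's skip condition); the tensor shapes are arbitrary examples:

    import torch

    if torch.fbgemm_is_cpu_supported():
        x = torch.randn(8, 32)    # FP32 activations, shape (batch, in_features)
        w = torch.randn(16, 32)   # FP32 weight, shape (out_features, in_features)
        b = torch.randn(16)       # FP32 bias

        # Pack once; weights outside the FP16 range would be saturated with a warning.
        packed_w = torch.fbgemm_pack_gemm_matrix_fp16(w)

        # Reuse the packed weight on every call; activations and output stay FP32.
        y = torch.fbgemm_linear_fp16_weight(x, packed_w, b)

        ref = x @ w.t() + b
        print(torch.allclose(y, ref, atol=1e-2, rtol=1e-2))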