163 changes: 150 additions & 13 deletions aten/src/ATen/native/QuantizedLinear.cpp
@@ -5,6 +5,7 @@

#ifdef USE_FBGEMM
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmFP16.h"
#include "fbgemm/QuantUtils.h"
#endif // USE_FBGEMM

@@ -22,6 +23,7 @@ namespace caffe2 {
#ifdef USE_FBGEMM
// Required for cpp_custom_type_hack to work
CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
CAFFE_KNOWN_TYPE(fbgemm::PackedGemmMatrixFP16);
#endif // USE_FBGEMM
}

@@ -41,24 +43,24 @@ Tensor fbgemm_linear_int8_weight(
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

auto input_contig = input.contiguous();
auto* input_ptr = input_contig.data<float>();

AT_ASSERT(input.dim() >= 2);
TORCH_CHECK(input.dim() >= 2);
int64_t M = 1;
for (size_t i = 0; i < input.dim() - 1; ++i) {
M *= input.size(i);
}
int64_t K = input.size(input.dim() - 1);
AT_ASSERT(weight.dim() == 2);
AT_ASSERT(K == weight.size(1));
TORCH_CHECK(weight.dim() == 2);
TORCH_CHECK(K == weight.size(1));
auto N = weight.size(0);
AT_ASSERT(bias.dim() == 1);
AT_ASSERT(bias.size(0) == N);
AT_ASSERT(weight_scale.isFloatingPoint());
AT_ASSERT(weight_zero_point.isIntegral());
TORCH_CHECK(bias.dim() == 1);
TORCH_CHECK(bias.size(0) == N);
TORCH_CHECK(weight_scale.isFloatingPoint());
TORCH_CHECK(weight_zero_point.isIntegral());

// Calculate statistics for quantization of the input Tensor
float x_min, x_max;
@@ -182,7 +184,7 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
auto weight_contig = weight.contiguous();

// Calculate weight statistics
@@ -238,7 +240,7 @@ Tensor fbgemm_pack_quantized_matrix(
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
auto weight_contig = weight.contiguous();
auto contiguous_ptr = weight_contig.data<int8_t>();
auto ptr = guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
@@ -252,6 +254,121 @@ Tensor fbgemm_pack_quantized_matrix(
return cpp_custom_type_hack::create(std::move(ptr), weight.options());
}

float raw_uint16_to_fp16(unsigned short value) {
// Convert the raw bit pattern of an IEEE 754 half-precision (binary16) value
// to a single-precision float. Only normalized values are handled, which is
// sufficient for the FP16_MAX constant used below.
unsigned short sign_bits = value >> 15;
unsigned short exponent_bits = value >> 10 & 0x1f;
unsigned short significand_bits = value & 0x3ff;

float sign = sign_bits ? -1 : 1;
float significand = 1 + significand_bits * 0x1p-10;
float exponent = exponent_bits - 0xf;

return sign * std::ldexp(significand, exponent);
}
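// Worked example of the decoding above, applied to 0x7BFF (the largest finite
// FP16 value): sign_bits = 0, exponent_bits = 0b11110 = 30, significand_bits =
// 0x3FF = 1023, so the result is (1 + 1023/1024) * 2^(30 - 15) = 65504.0f.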

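// Clamps *element to the symmetric range [-MAX, MAX] and returns true iff the
// value had to be saturated.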
template <typename T>
bool check_and_saturate(T* element, T MAX) {
if (*element > MAX) {
*element = MAX;
return true;
}
if (*element < -MAX) {
*element = -MAX;
return true;
}
return false;
}

// FP16 can represent magnitudes only in roughly [5.96e-8, 65504]. Any weight
// whose magnitude exceeds the FP16 maximum is saturated to +/-65504 before
// packing, and a warning is emitted when that happens.
void handle_weights_saturation(float* weight, int64_t length) {
float FP16_MAX = raw_uint16_to_fp16(0x7BFF);
bool found_out_of_range = false;

for (int64_t i = 0; i < length; ++i) {
if (check_and_saturate<float>(&weight[i], FP16_MAX)) {
found_out_of_range = true;
}
}

if (found_out_of_range) {
TORCH_WARN("FOUND weight out of range ");
}
}
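// For example, with FP16_MAX == 65504 a weight buffer {0.5f, 1.0e6f, -1.0e6f}
// is rewritten in place to {0.5f, 65504.0f, -65504.0f}, and the warning above
// is emitted once.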

Tensor fbgemm_pack_gemm_matrix_fp16(
const Tensor& weight) {
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

int64_t K = weight.size(1);
int64_t N = weight.size(0);
Tensor weight_contig = weight.contiguous();
auto weight_contig_ptr = weight_contig.data<float>();

handle_weights_saturation(weight_contig_ptr, K*N);

// TODO(mingzhe09088):
// Consider using a functor here in PackedGemmMatrixFP16.
// Comment from XQ: Not entirely sure this make_unique is safe. The object is
// allocated with a regular "new" inside make_unique and freed through
// TypeMetaData::deleteFn. That is perfectly fine as long as the tensor is
// created and freed within this translation unit, but it could be very
// problematic if the tensor flows across DLL boundaries.
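// Note: weight is row-major [N, K]; packing it with matrix_op_t::Transpose
// yields the K x N right-hand-side matrix consumed by fbgemm_linear_fp16_weight
// (numRows() == K, numCols() == N).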
auto ptr = guts::make_unique<fbgemm::PackedGemmMatrixFP16>(
fbgemm::matrix_op_t::Transpose,
K,
N,
1,
weight_contig_ptr);
return cpp_custom_type_hack::create(std::move(ptr), weight.options());
}

Tensor fbgemm_linear_fp16_weight(
const Tensor& input,
const Tensor& packed_weight,
const Tensor& bias) {
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

auto input_contig = input.contiguous();
auto* input_ptr = input_contig.data<float>();

// Pull out the PackedGemmMatrixFP16 instance from the owning tensor
const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 =
cpp_custom_type_hack::cast<fbgemm::PackedGemmMatrixFP16>(packed_weight);

TORCH_CHECK(input.dim() >= 2);
TORCH_CHECK(input.size(1) == packed_weight_fp16.numRows());
int64_t M = input.size(0);
TORCH_CHECK(bias.dim() == 1);
int64_t N = packed_weight_fp16.numCols();

auto output = at::empty({M, N}, bias.options().dtype(at::kFloat));

// Call the fp16 gemm interface
fbgemm::cblas_gemm_compute(
fbgemm::matrix_op_t::NoTranspose,
M,
input_ptr,
packed_weight_fp16,
0.f,
output.data<float>());

// Add bias term
output.add_(bias);

return output;
}
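Taken together, the two new operators compute y = x * W^T + b with the weight pre-packed in half precision. A minimal usage sketch from C++ (illustration only; it assumes the default at:: function bindings generated from the native_functions.yaml entries added below):

#include <ATen/ATen.h>

at::Tensor fp16_linear_example() {
  at::Tensor x = at::randn({8, 16});   // input  [M, K], float32
  at::Tensor w = at::randn({4, 16});   // weight [N, K], float32
  at::Tensor b = at::randn({4});       // bias   [N],    float32
  // Saturate and pack the weight once (typically at model load time)...
  at::Tensor packed = at::fbgemm_pack_gemm_matrix_fp16(w);
  // ...then reuse the packed weight on every forward call; output is [M, N].
  return at::fbgemm_linear_fp16_weight(x, packed, b);
}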

#else // USE_FBGEMM

Tensor fbgemm_linear_int8_weight(
@@ -265,7 +382,7 @@
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
AT_ASSERTM(
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}

@@ -274,7 +391,7 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
AT_ASSERTM(
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}

@@ -285,7 +402,27 @@ Tensor fbgemm_pack_quantized_matrix(
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
AT_ASSERTM(
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_pack_gemm_matrix_fp16(
const Tensor& weight) {
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}

Tensor fbgemm_linear_fp16_weight(
const Tensor& input,
const Tensor& packed_weight,
const Tensor& bias) {
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}

4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1101,6 +1101,10 @@

- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)

- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor

- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

- func: fbgemm_pack_quantized_matrix(Tensor input, int K, int N) -> Tensor

- func: fbgemm_is_cpu_supported() -> bool
21 changes: 21 additions & 0 deletions test/test_nn.py
@@ -36,6 +36,7 @@
module_tests, criterion_tests, loss_reference_fns, get_reduction, \
get_weight, smoothl1loss_reference, kldivloss_reference, \
ctcloss_reference, new_module_tests
from common_utils import TEST_WITH_UBSAN

from torch.nn import MultiheadAttention

@@ -2516,6 +2517,26 @@ def test_softmax_backward(self):
def test_softmax_backward_cuda(self):
self._test_softmax_backward(torch.device('cuda'))

@unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
'Linear_FP16_weight requires FBGEMM. FBGEMM does not play'
' well with UBSAN at the moment, so we skip the test if'
' we are in a UBSAN environment.')
def test_fb_fc_packed(self):
X = np.random.rand(16, 16).astype(np.float32) - 0.5
W = np.random.rand(16, 16).astype(np.float32) - 0.5
b = np.random.rand(16).astype(np.float32) - 0.5

def fc_op(X, W, b):
return np.dot(X, W.T) + b

x_tensor = torch.tensor(X)
w_tensor = torch.tensor(W)
b_tensor = torch.tensor(b)
packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
expected_output = fc_op(X, W, b)
torch.testing.assert_allclose(expected_output, actual_output.cpu(), atol=1e-3, rtol=1e-3)

def _test_gumbel_softmax_st_shapes(self, cuda, dtype, shape, dim, count_expected):
logits = torch.randn(shape, dtype=torch.float)
logits = logits.to(dtype)
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -211,6 +211,9 @@
- name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor
self: not_implemented("cholesky_inverse")

- name: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
self: not_implemented("fbgemm_linear_fp16_weight only supported for inference")

# For clamp, gradient is not defined at the boundaries. But empirically it's helpful
# to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
- name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor