Skip to content

Commit

Permalink
Update on "[quant][graphmode][fx] Merge all quantization mode"
Browse files Browse the repository at this point in the history
Summary:
This PR merges all quantization modes and exposes only the following top-level functions:
```
prepare_fx
prepare_qat_fx
convert_fx
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23913105](https://our.internmc.facebook.com/intern/diff/D23913105)

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Sep 30, 2020
2 parents 8541de3 + 6e9abe7 commit 1d191e0
Show file tree
Hide file tree
Showing 44 changed files with 984 additions and 171 deletions.
26 changes: 20 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,29 @@ option(ONNX_ML "Enable traditional ONNX ML API." ON)
option(HAVE_SOVERSION "Whether to add SOVERSION to the shared objects" OFF)

# TensorPipe does not support Windows, so it is always switched OFF there.
# Distributed builds on Windows additionally need libuv: when the user has not
# set the libuv_ROOT environment variable, look for libuv in the active conda
# environment. If it cannot be found, USE_DISTRIBUTED (and USE_GLOO) are set to
# OFF with a warning instead of failing the configure step.
if(WIN32)
  set(USE_TENSORPIPE OFF)
  message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF")

  if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT})
    # Deliberately no REQUIRED keyword: with CMake >= 3.18 it stops
    # configuration with an error when the library is not found, which would
    # make the graceful fallback below unreachable.
    find_library(
      libuv_tmp_LIBRARY
      NAMES uv libuv
      HINTS $ENV{CONDA_PREFIX}\\Library
      PATH_SUFFIXES lib
      NO_DEFAULT_PATH)
    # Quoted so the check stays well-formed whatever the variable holds
    # (a "-NOTFOUND" value simply does not exist on disk).
    if(NOT EXISTS "${libuv_tmp_LIBRARY}")
      set(USE_DISTRIBUTED OFF)
      set(USE_GLOO OFF)
      message(
        WARNING "Libuv is not installed in current conda env. Set USE_DISTRIBUTED to OFF.")
    else()
      # libuv was found under the conda prefix; expose that location through
      # the libuv_ROOT environment variable.
      set(ENV{libuv_ROOT} $ENV{CONDA_PREFIX}\\Library)
    endif()
  endif()
endif()

# Linux distributions do not want too many embedded sources, in that sense we
Expand Down Expand Up @@ -292,12 +312,6 @@ if(LINUX)
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed")
endif()

# On Windows distributed builds, default libuv_ROOT to the conda environment's
# Library directory unless the user already set it.
if(WIN32 AND USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT})
  set(ENV{libuv_ROOT} $ENV{CONDA_PREFIX}\\Library)
endif()

if(MSVC)
foreach(flag_var
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4437,7 +4437,7 @@
use_c10_dispatcher: full
variants: function

- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (float, float)
- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)
use_c10_dispatcher: full
variants: function

Expand Down
17 changes: 10 additions & 7 deletions aten/src/ATen/native/quantized/QTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,14 @@ float calculate_quant_loss(
float scale = data_range == 0
? 1.0
: static_cast<float>(static_cast<at::Half>(data_range / qmax));
float inverse_scale = 1.0f / scale;
float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale;

float norm = 0.0f;
constexpr int VLEN = 8;
int i = 0;

// TODO add FBGEMM kernel
// #ifdef USE_FBGEMM
// #endif
// TODO add FBGEMM kernel
// #ifdef USE_FBGEMM
// #endif

// remainder loop
for (; i < numel; i++) {
Expand All @@ -271,7 +270,7 @@ float calculate_quant_loss(
and tries to minimize the quant error by doing `torch.norm(x-fake_quant(x,s,z))`
Returns the optimized xmax and xmin value of the tensor.
*/
std::tuple<double, double> choose_qparams_optimized(
std::tuple<Tensor, Tensor> choose_qparams_optimized(
const at::Tensor& input_tensor,
int64_t numel,
const int64_t n_bins,
Expand Down Expand Up @@ -318,7 +317,11 @@ std::tuple<double, double> choose_qparams_optimized(
}
}

return std::make_tuple((float) xmax, (float) xmin);
at::Tensor xmax_tensor = at::empty({1});
at::Tensor xmin_tensor = at::empty({1});
xmax_tensor[0] = xmax;
xmin_tensor[0] = xmin;
return std::make_tuple(xmax_tensor, xmin_tensor);
}
} // namespace native
} // namespace at
16 changes: 13 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,14 @@ Tensor _qembeddingbag_nbit_prepack_helper(

float Xmin, Xmax;
if (optimized_qparams) {
std::tie(Xmax, Xmin) = at::choose_qparams_optimized(
at::Tensor xmax_tensor, xmin_tensor;
std::tie(xmax_tensor, xmin_tensor) = at::choose_qparams_optimized(
weight_contig[row], embedding_cols, 200, 0.16, bit_width);
TORCH_CHECK(
xmax_tensor.numel() == 1 && xmin_tensor.numel() == 1,
"Expected choose_qparams_optimized to return min/max tensors of size 1");
Xmax = xmax_tensor.item<float>();
Xmin = xmin_tensor.item<float>();
} else {
Xmin = *std::min_element(input_row, input_row + embedding_cols);
Xmax = *std::max_element(input_row, input_row + embedding_cols);
Expand Down Expand Up @@ -254,7 +260,9 @@ Tensor _qembeddingbag_nbit_prepack_helper(
// To later de-quantize values, the scale (range / 15) and zero_point
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
Tensor qembeddingbag_4bit_prepack(const Tensor& weight, bool optimized_qparams) {
// Entry point for 4-bit prepacking: forwards `weight` and `optimized_qparams`
// unchanged to the shared n-bit prepack helper with a bit width of 4.
Tensor qembeddingbag_4bit_prepack(
    const Tensor& weight,
    bool optimized_qparams) {
  constexpr int kBitWidth = 4;
  return _qembeddingbag_nbit_prepack_helper(
      weight, kBitWidth, optimized_qparams);
}
Expand All @@ -267,7 +275,9 @@ Tensor qembeddingbag_4bit_prepack(const Tensor& weight, bool optimized_qparams)
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
// TODO() - Add 2Bit Embedding Lookup operator.
Tensor qembeddingbag_2bit_prepack(const Tensor& weight, bool optimized_qparams) {
// Entry point for 2-bit prepacking: forwards `weight` and `optimized_qparams`
// unchanged to the shared n-bit prepack helper with a bit width of 2.
Tensor qembeddingbag_2bit_prepack(
    const Tensor& weight,
    bool optimized_qparams) {
  constexpr int kBitWidth = 2;
  return _qembeddingbag_nbit_prepack_helper(
      weight, kBitWidth, optimized_qparams);
}
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/bucketize_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ __global__ void BucketizeOpKernel(
CUDA_1D_KERNEL_LOOP(i, N) {
int32_t low = -1, high = M;
while (high - low > 1) {
int32_t median = (high + low) / 2;
const int32_t median = low + (high - low) / 2;
if (bounds[median] < X[i]) {
low = median;
} else {
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_align_gradient_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ void ROIAlignBackwardFeature(
} // namespace

template <>
bool RoIAlignGradientOp<float, CPUContext>::RunOnDevice() {
C10_EXPORT bool RoIAlignGradientOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs
auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_align_gradient_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ __global__ void RoIAlignBackwardFeature(
} // namespace

template <>
bool RoIAlignGradientOp<float, CUDAContext>::RunOnDevice() {
C10_EXPORT bool RoIAlignGradientOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs
auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
Expand Down
4 changes: 2 additions & 2 deletions caffe2/operators/roi_align_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ std::vector<BilinearInterpolationParam<T>> MakeBilinearInterpolationParams(
} // namespace

template <>
bool RoIAlignOp<float, CPUContext>::RunOnDeviceWithOrderNCHW(
C10_EXPORT bool RoIAlignOp<float, CPUContext>::RunOnDeviceWithOrderNCHW(
int64_t N,
int64_t C,
int64_t H,
Expand Down Expand Up @@ -170,7 +170,7 @@ bool RoIAlignOp<float, CPUContext>::RunOnDeviceWithOrderNCHW(
}

template <>
bool RoIAlignOp<float, CPUContext>::RunOnDeviceWithOrderNHWC(
C10_EXPORT bool RoIAlignOp<float, CPUContext>::RunOnDeviceWithOrderNHWC(
int64_t N,
int64_t C,
int64_t H,
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_align_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ __global__ void RoIAlignForward(
} // namespace

template <>
bool RoIAlignOp<float, CUDAContext>::RunOnDevice() {
C10_EXPORT bool RoIAlignOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs
// RoI pooled data
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_align_rotated_gradient_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ __global__ void RoIAlignRotatedBackward(
} // namespace

template <>
bool RoIAlignRotatedGradientOp<float, CUDAContext>::RunOnDevice() {
C10_EXPORT bool RoIAlignRotatedGradientOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs
auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_align_rotated_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ void ROIAlignRotatedForward(
} // namespace

template <>
bool RoIAlignRotatedOp<float, CPUContext>::RunOnDevice() {
C10_EXPORT bool RoIAlignRotatedOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs

Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_align_rotated_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ __global__ void RoIAlignRotatedForward(
} // namespace

template <>
bool RoIAlignRotatedOp<float, CUDAContext>::RunOnDevice() {
C10_EXPORT bool RoIAlignRotatedOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs

Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using std::max;
using std::min;

template <>
bool RoIPoolOp<float, CPUContext>::RunOnDevice() {
C10_EXPORT bool RoIPoolOp<float, CPUContext>::RunOnDevice() {
const auto& X = Input(0); // Input data to pool
const auto& R = Input(1); // RoIs
auto* Y = Output(0); // RoI pooled data
Expand Down
2 changes: 1 addition & 1 deletion caffe2/operators/roi_pool_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ bool RoIPoolOp<float, CUDAContext>::RunOnDevice() {
}

template <>
bool RoIPoolGradientOp<float, CUDAContext>::RunOnDevice() {
C10_EXPORT bool RoIPoolGradientOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
auto& R = Input(1); // RoIs
auto& A = Input(2); // argmaxes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
("aten::_foreach_sub_", datetime.date(2020, 10, 1)),
("aten::_foreach_div", datetime.date(2020, 10, 1)),
("aten::_foreach_sub", datetime.date(2020, 10, 1)),
("aten::choose_qparams_optimized", datetime.date(2020, 10, 5)),
]


Expand Down
50 changes: 50 additions & 0 deletions test/cpp/api/functional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,56 @@ TEST_F(FunctionalTest, TripletMarginLoss) {
ASSERT_TRUE(output.allclose(expected, 1e-04));
}

TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) {
  // Parity check: with no custom distance function supplied,
  // triplet_margin_with_distance_loss must match triplet_margin_loss for
  // every combination of reduction, margin and swap exercised below.
  std::vector<TripletMarginWithDistanceLossOptions::reduction_t>
      reductionModes = {torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> marginValues = {0.5, 1.0, 1.5};
  std::vector<bool> swapFlags = {true, false};

  // Fresh random 100x128 float input that tracks gradients.
  const auto makeInput = []() {
    return torch::randn(
        {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
  };

  for (auto& reductionMode : reductionModes) {
    for (auto& marginValue : marginValues) {
      for (const auto& swapFlag : swapFlags) {
        auto anchor = makeInput();
        auto positive = makeInput();
        auto negative = makeInput();

        // Identical settings for the plain and the distance-based variant.
        auto plainOpts = F::TripletMarginLossFuncOptions()
                             .reduction(reductionMode)
                             .margin(marginValue)
                             .swap(swapFlag);
        auto distOpts = F::TripletMarginWithDistanceLossFuncOptions()
                            .reduction(reductionMode)
                            .margin(marginValue)
                            .swap(swapFlag);
        TripletMarginLoss plainModule(plainOpts);
        TripletMarginWithDistanceLoss distModule(distOpts);

        auto plainResult =
            F::triplet_margin_loss(anchor, positive, negative, plainOpts);
        auto distResult = F::triplet_margin_with_distance_loss(
            anchor, positive, negative, distOpts);

        ASSERT_TRUE(distResult.allclose(plainResult, 1e-6, 1e-6));

        // Summing first yields a scalar to back-propagate from, which also
        // covers the torch::kNone reduction.
        auto total = distResult.sum();
        total.backward();
        ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
        ASSERT_EQ(positive.sizes(), positive.grad().sizes());
        ASSERT_EQ(negative.sizes(), negative.grad().sizes());
      }
    }
  }
}

TEST_F(FunctionalTest, NLLLoss) {
auto input = torch::tensor({{-0.1315, -3.1315, -2.5315},
{-3.7038, -0.1038, -2.6038},
Expand Down

0 comments on commit 1d191e0

Please sign in to comment.