Skip to content

Commit

Permalink
[Caffe2] use the real new fbgemm sparse adagrad interface (#46132)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #46132

As title

Test Plan: .

Reviewed By: dskhudia

Differential Revision: D24197694

fbshipit-source-id: 2bfe8f52409fa500d2ea359dec7f521cffb20efb
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Oct 10, 2020
1 parent 9f74301 commit 4c87d33
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions caffe2/sgd/adagrad_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,14 @@ class SparseAdagradOp final : public Operator<CPUContext> {
if (block_size != last_block_size_) {
last_block_size_ = block_size;
if (std::is_same<SIndex, std::int32_t>::value) {
kernel_i32_ = fbgemm::GenerateSparseAdaGradNew<std::int32_t>(
kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>(
block_size,
/*rowwise=*/false,
/*prefetch=*/16,
weight_decay_ != 0.0f);
} else {
CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value));
kernel_i64_ = fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
block_size,
/*rowwise=*/false,
/*prefetch=*/16,
Expand Down Expand Up @@ -350,8 +350,8 @@ class SparseAdagradOp final : public Operator<CPUContext> {
float epsilon_;
const float weight_decay_;
#if defined(USE_FBGEMM) && !defined(__NVCC__)
fbgemm::SparseAdaGradSignature<std::int32_t>::NewType kernel_i32_;
fbgemm::SparseAdaGradSignature<std::int64_t>::NewType kernel_i64_;
fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_;
fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_;
std::int64_t last_block_size_{-1};
#endif

Expand Down Expand Up @@ -428,14 +428,14 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
if (block_size != last_block_size_) {
last_block_size_ = block_size;
if (std::is_same<SIndex, std::int32_t>::value) {
kernel_i32_ = fbgemm::GenerateSparseAdaGradNew<std::int32_t>(
kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>(
block_size,
/*rowwise=*/true,
/*prefetch=*/16,
weight_decay_ != 0.0f);
} else {
CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value));
kernel_i64_ = fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
block_size,
/*rowwise=*/true,
/*prefetch=*/16,
Expand Down Expand Up @@ -545,8 +545,8 @@ class RowWiseSparseAdagradOp final : public Operator<Context> {
float epsilon_;
const float weight_decay_;
#if defined(USE_FBGEMM) && !defined(__NVCC__)
fbgemm::SparseAdaGradSignature<std::int32_t>::NewType kernel_i32_;
fbgemm::SparseAdaGradSignature<std::int64_t>::NewType kernel_i64_;
fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_;
fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_;
std::int64_t last_block_size_{-1};
#endif

Expand Down

0 comments on commit 4c87d33

Please sign in to comment.