Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions aten/src/ATen/core/typeid.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,23 +410,23 @@ inline bool operator!=(const TypeMeta& lhs, const TypeMeta& rhs) noexcept {
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0537r0.html
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51930
// and as a result, we define these two macros slightly differently.
#ifdef _MSC_VER
#if defined(_MSC_VER) || defined(__clang__)
#define CAFFE_KNOWN_TYPE(T) \
template <> \
C10_EXPORT TypeIdentifier TypeMeta::Id<T>() { \
static const TypeIdentifier type_id = TypeIdentifier::createTypeId(); \
static TypeNameRegisterer<T> registerer(type_id, #T); \
return type_id; \
}
#else // _MSC_VER
#else // defined(_MSC_VER) || defined(__clang__)
#define CAFFE_KNOWN_TYPE(T) \
template <> \
TypeIdentifier TypeMeta::Id<T>() { \
static const TypeIdentifier type_id = TypeIdentifier::createTypeId(); \
static TypeNameRegisterer<T> registerer(type_id, #T); \
return type_id; \
}
#endif
#endif // defined(_MSC_VER) || defined(__clang__)

/**
* CAFFE_DECLARE_KNOWN_TYPE and CAFFE_DEFINE_KNOWN_TYPE are used
Expand Down
5 changes: 1 addition & 4 deletions binaries/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ if (USE_OPENCV)
endif()

if (USE_OBSERVERS)
add_executable(caffe2_benchmark "caffe2_benchmark.cc" "benchmark_helper.cc")
target_link_libraries(caffe2_benchmark ${Caffe2_MAIN_LIBS})
target_link_libraries(caffe2_benchmark ${Caffe2_MODULES})
install(TARGETS caffe2_benchmark DESTINATION bin)
caffe2_binary_target(caffe2_benchmark "caffe2_benchmark.cc" "benchmark_helper.cc")
endif()

# ---[ tutorials
Expand Down
13 changes: 9 additions & 4 deletions binaries/benchmark_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "caffe2/core/logging.h"
#include "caffe2/core/net.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/utils/bench_utils.h"
#include "caffe2/utils/string_utils.h"
#include "observers/net_observer_reporter_print.h"
Expand Down Expand Up @@ -163,12 +164,16 @@ void loadInput(
CAFFE_THROW("Not support GPU on mobile.");
#endif
} else {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {
tensor->mutable_data<uint8_t>();
caffe2::int8::Int8TensorCPU* tensor =
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
CHECK_NOTNULL(tensor);
tensor->t.Resize(input_dims);
tensor->t.mutable_data<uint8_t>();
} else if (input_type_list[i] == "float") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<float>();
} else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
Expand Down
15 changes: 10 additions & 5 deletions binaries/speed_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "caffe2/core/init.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor_int8.h"
#ifdef CAFFE2_OPTIMIZER
#include "caffe2/opt/optimizer.h"
#endif
Expand Down Expand Up @@ -137,14 +138,18 @@ int main(int argc, char** argv) {
if (blob == nullptr) {
blob = workspace->CreateBlob(input_names[i]);
}
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {
tensor->mutable_data<uint8_t>();
caffe2::int8::Int8TensorCPU* tensor =
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
CHECK_NOTNULL(tensor);
tensor->t.Resize(input_dims);
tensor->t.mutable_data<uint8_t>();
} else if (input_type_list[i] == "float") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<float>();
} else {
} else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
}
}
Expand Down
7 changes: 0 additions & 7 deletions caffe2/operators/batch_matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,6 @@ class GetBatchMatMulGradient : public GradientMakerBase {
auto trans_both_arg = vector<Argument>{MakeArgument<int>("trans_a", 1),
MakeArgument<int>("trans_b", 1)};

if (ArgumentHelper::HasArgument(Def(), "use_scratch")) {
no_trans_arg.push_back(MakeArgument<int>("use_scratch", 1));
trans_a_arg.push_back(MakeArgument<int>("use_scratch", 1));
trans_b_arg.push_back(MakeArgument<int>("use_scratch", 1));
trans_both_arg.push_back(MakeArgument<int>("use_scratch", 1));
}

if (trans_a) {
if (trans_b) {
// A'B':
Expand Down
10 changes: 1 addition & 9 deletions caffe2/operators/batch_matmul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ class BatchMatMulOp final : public Operator<Context> {
: Operator<Context>(operator_def, ws),
trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
trans_b_(this->template GetSingleArgument<int>("trans_b", 0)),
broadcast_(this->template GetSingleArgument<int>("broadcast", 0)),
use_scratch_(this->template GetSingleArgument<int>("use_scratch", 0)) {
if (use_scratch_) {
scratch_ = std::make_shared<Tensor>(Context::GetDeviceType());
}
}
broadcast_(this->template GetSingleArgument<int>("broadcast", 0)) {}

~BatchMatMulOp() {}

Expand Down Expand Up @@ -280,9 +275,6 @@ class BatchMatMulOp final : public Operator<Context> {
bool trans_a_;
bool trans_b_;
bool broadcast_;

bool use_scratch_;
std::shared_ptr<Tensor> scratch_;
};

} // namespace caffe2
Expand Down
7 changes: 5 additions & 2 deletions cmake/public/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ endmacro()
# target name is given by the first argument and the rest are the source files
# to build the target.
function(caffe2_binary_target target_name_or_src)
if (${ARGN})
# https://cmake.org/cmake/help/latest/command/function.html
# Checking that ARGC is greater than # is the only way to ensure
# that ARGV# was passed to the function as an extra argument.
if (ARGC GREATER 1)
set(__target ${target_name_or_src})
prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}")
else()
Expand All @@ -110,7 +113,7 @@ function(caffe2_binary_target target_name_or_src)
endfunction()

function(caffe2_hip_binary_target target_name_or_src)
if (${ARGN})
if (ARGC GREATER 1)
set(__target ${target_name_or_src})
prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}")
else()
Expand Down