From fc1153a8be7cb1d4e7d1219746dc744cb8e28b4f Mon Sep 17 00:00:00 2001
From: Meghan Lele
Date: Wed, 2 Dec 2020 12:28:09 -0800
Subject: [PATCH] [JIT] Fix clang-tidy warnings in jit/runtime (#47992)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47992

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D25258645

Pulled By: SplitInfinity

fbshipit-source-id: b3e4576400c101b247e80cb4044fc04471f39a47
---
 torch/csrc/jit/runtime/graph_executor.cpp     | 10 +++++-----
 torch/csrc/jit/runtime/graph_executor.h       |  2 +-
 torch/csrc/jit/runtime/interpreter.cpp        |  2 +-
 torch/csrc/jit/runtime/interpreter.h          |  2 +-
 torch/csrc/jit/runtime/logging.h              |  6 +++---
 torch/csrc/jit/runtime/operator.h             |  8 +++-----
 torch/csrc/jit/runtime/register_ops_utils.h   |  2 +-
 torch/csrc/jit/runtime/register_prim_ops.cpp  |  2 +-
 .../jit/runtime/register_prim_ops_fulljit.cpp |  2 +-
 torch/csrc/jit/runtime/vararg_functions.cpp   | 24 +++++++++++++++---------
 torch/csrc/jit/runtime/vararg_functions.h     | 12 +++++++++---
 11 files changed, 41 insertions(+), 31 deletions(-)

diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp
index c85bfb3169a8..7e258b576f96 100644
--- a/torch/csrc/jit/runtime/graph_executor.cpp
+++ b/torch/csrc/jit/runtime/graph_executor.cpp
@@ -131,7 +131,7 @@ struct CaptureList {
       auto tensors = val.toTensorList();
       sizes_.push_back(tensors.size());
 
-      for (const at::Tensor& tensor : tensors) {
+      for (const at::Tensor tensor : tensors) {
         captureTensor(tensor, is_output);
       }
     } else {
@@ -269,7 +269,7 @@ struct DifferentiableGraphBackward : public autograd::Node {
     size_t output_index = 0;
     for (IValue& v : stack) {
       if (v.isTensorList()) {
-        for (const at::Tensor& tensor : v.toTensorList()) {
+        for (at::Tensor tensor : v.toTensorList()) {
           produceOutput(output_index++, std::move(tensor), outputs);
         }
       } else if (v.isTensor()) {
@@ -295,7 +295,7 @@ struct DifferentiableGraphBackward : public autograd::Node {
   }
   void addOutputForIValue(const IValue& value) {
     if (value.isTensorList()) {
-      for (const at::Tensor& tensor : value.toTensorList()) {
+      for (const at::Tensor tensor : value.toTensorList()) {
         addOutputForTensor(tensor);
       }
     } else {
@@ -319,7 +319,7 @@ struct DifferentiableGraphBackward : public autograd::Node {
     if (v.isTensorList()) {
       auto tensors = v.toTensorList();
       input_instructions_.pushTensorList(tensors.size());
-      for (const at::Tensor& tensor : tensors) {
+      for (const at::Tensor tensor : tensors) {
         addInputVariable(tensor);
      }
     } else if (v.isTensor()) {
@@ -719,7 +719,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
 };
 
 GraphExecutor::GraphExecutor(
-    std::shared_ptr<Graph> graph,
+    const std::shared_ptr<Graph>& graph,
     std::string function_name)
     : pImpl(
           IsNewExecutorEnabled()
diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h
index 5de2cd6f89f7..1b938c187648 100644
--- a/torch/csrc/jit/runtime/graph_executor.h
+++ b/torch/csrc/jit/runtime/graph_executor.h
@@ -55,7 +55,7 @@ struct TORCH_API EnableProfilingGuard {
 struct GraphExecutorImplBase;
 struct TORCH_API GraphExecutor {
   GraphExecutor() = default;
-  GraphExecutor(std::shared_ptr<Graph> graph, std::string function_name);
+  GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);
 
   void run(Stack& inputs);
   c10::intrusive_ptr<Future> runAsync(
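A note on the tensor loops in graph_executor.cpp above: c10::List<at::Tensor> does not hand out references to stored at::Tensor objects; dereferencing its iterator returns a temporary by value, so a `const at::Tensor&` loop variable merely binds to that temporary, which is what clang-tidy flags. Iterating by value costs only a reference-count bump, since at::Tensor is a shared handle to the underlying storage. A minimal sketch of the pattern (sumAll is a hypothetical helper, not part of this patch):

  #include <ATen/ATen.h>
  #include <ATen/core/List.h>

  // c10::List<at::Tensor> iterators return elements by value, so take the
  // loop variable by value; the copy shares storage with the list element.
  at::Tensor sumAll(const c10::List<at::Tensor>& tensors) {
    at::Tensor acc = at::zeros({1});
    for (const at::Tensor t : tensors) {
      acc = acc + t.sum();
    }
    return acc;
  }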
diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp
index 7c487c5546b0..4f6fc77da260 100644
--- a/torch/csrc/jit/runtime/interpreter.cpp
+++ b/torch/csrc/jit/runtime/interpreter.cpp
@@ -706,7 +706,7 @@ struct CodeImpl {
   void emitCall(Function* func, at::ArrayRef<Value*> inputs) {
     emitLoadInputs(inputs);
     insertInstruction(CALL, function_table_.size());
-    function_table_.emplace_back(std::move(func));
+    function_table_.emplace_back(func);
   }
 
   void emitNodeAtBlockLevel(Node* node) {
diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h
index 9b40aed75d42..025ac67f6f2e 100644
--- a/torch/csrc/jit/runtime/interpreter.h
+++ b/torch/csrc/jit/runtime/interpreter.h
@@ -102,7 +102,7 @@ struct Suspend : public std::exception {
 // thread local settings are propagated with ThreadLocalState
 struct InterpreterContinuation {
   InterpreterContinuation(
-      InterpreterState state_,
+      const InterpreterState& state_,
       Stack stack_,
       int64_t dist_autograd_context_id = 0,
       c10::optional<at::ThreadLocalState> tls_state = c10::nullopt)
diff --git a/torch/csrc/jit/runtime/logging.h b/torch/csrc/jit/runtime/logging.h
index f5f4559e65f4..ce0339410f89 100644
--- a/torch/csrc/jit/runtime/logging.h
+++ b/torch/csrc/jit/runtime/logging.h
@@ -16,7 +16,7 @@ class LoggerBase {
   TORCH_API virtual void addStatValue(
       const std::string& stat_name,
       int64_t val) = 0;
-  virtual ~LoggerBase() {}
+  virtual ~LoggerBase() = default;
 };
 
 TORCH_API LoggerBase* getLogger();
@@ -28,7 +28,7 @@ TORCH_API LoggerBase* setLogger(LoggerBase* logger);
 class NoopLogger : public LoggerBase {
  public:
   void addStatValue(const std::string& stat_name, int64_t val) override {}
-  ~NoopLogger() {}
+  ~NoopLogger() = default;
 };
 
 // Trivial locking logger. Pass in an instance of this to setLogger() to use it.
@@ -42,7 +42,7 @@ class TORCH_API LockingLogger : public LoggerBase {
   virtual int64_t getCounterValue(const std::string& name) const;
   enum class AggregationType { SUM = 0, AVG = 1 };
   void setAggregationType(const std::string& stat_name, AggregationType type);
-  ~LockingLogger() {}
+  ~LockingLogger() = default;
 
 private:
   mutable std::mutex m;
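The logging.h changes apply clang-tidy's modernize-use-equals-default: `= default` states the intent explicitly, and a defaulted destructor, unlike one with an empty body `{}`, does not count as user-provided, so for non-virtual types it preserves triviality. A small illustration with a made-up Sink class (not from this patch):

  #include <string>

  class Sink {
   public:
    virtual void write(const std::string& line) = 0;
    // Preferred over `virtual ~Sink() {}`.
    virtual ~Sink() = default;
  };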
diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h
index e0528afa27b3..ec3aa3d4fb99 100644
--- a/torch/csrc/jit/runtime/operator.h
+++ b/torch/csrc/jit/runtime/operator.h
@@ -73,7 +73,7 @@ struct TORCH_API Operator {
  public:
   Operator(c10::OperatorHandle opHandle, Operation operation)
       : op_(c10::make_left<C10Operator, JitOnlyOperator>(
-            C10Operator{std::move(opHandle), std::move(operation)})) {}
+            C10Operator{opHandle, std::move(operation)})) {}
 
   Operator(
       std::string schema,
@@ -102,8 +102,7 @@ struct TORCH_API Operator {
       : op_(c10::make_right<C10Operator, JitOnlyOperator>(JitOnlyOperator{
             c10::make_right<FunctionSchema, UnparsedFunctionSchema>(
                 UnparsedFunctionSchema{std::move(schema), alias_analysis}),
-            c10::make_right<Operation, OperationCreator>(
-                std::move(op_creator))})) {}
+            c10::make_right<Operation, OperationCreator>(op_creator)})) {}
 
   // Helper constructor to register `op` to run
   // run for _every_ IR Node where n.kind() == name, regardless of arguments.
@@ -116,8 +115,7 @@ struct TORCH_API Operator {
       : op_(c10::make_right<C10Operator, JitOnlyOperator>(JitOnlyOperator{
             c10::make_left<FunctionSchema, UnparsedFunctionSchema>(
                 varArgSchemaWithName(name, alias_analysis)),
-            c10::make_right<Operation, OperationCreator>(
-                std::move(op_creator))})) {}
+            c10::make_right<Operation, OperationCreator>(op_creator)})) {}
 
   Operation getOperation(const Node* node = nullptr) const {
     return op_.fold(
diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h
index 2f72ffad3302..42464cecd89a 100644
--- a/torch/csrc/jit/runtime/register_ops_utils.h
+++ b/torch/csrc/jit/runtime/register_ops_utils.h
@@ -179,7 +179,7 @@ void setItem(const c10::List<T>& list, int64_t idx, T&& value) {
   if (normalized_idx < 0 || normalized_idx >= list_size) {
     throw std::out_of_range("list index out of range");
   }
-  list.set(normalized_idx, std::move(value));
+  list.set(normalized_idx, std::forward<T>(value));
 }
 
 void listAppend(Stack* stack);
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 4b3fd88c00c7..f031d957449b 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1347,7 +1347,7 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size) {
 
 int64_t stringFindImpl(
     std::string string,
-    std::string substr,
+    const std::string& substr,
     int64_t start,
     int64_t end,
     bool reverse = false) {
diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
index bacda807f7f2..68b9b54dd42c 100644
--- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
@@ -348,7 +348,7 @@ RegisterOperators reg(
                 at::infer_size(size, peek(stack, i, num_inputs).toIntVector());
           }
           drop(stack, num_inputs);
-          push(stack, IValue(std::move(size)));
+          push(stack, IValue(size));
         },
         aliasAnalysisSpecialCase()),
     Operator(
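The setItem fix above is the classic std::move-on-a-forwarding-reference bug: `value` is declared T&&, and std::move would turn it into an rvalue even when the caller passed an lvalue it still owns, allowing the callee to move from it; std::forward moves only when an rvalue was actually passed. A self-contained sketch with a hypothetical setAt (analogous to setItem, assuming a std::vector instead of c10::List):

  #include <cstddef>
  #include <string>
  #include <utility>
  #include <vector>

  // std::forward preserves the caller's value category: lvalue arguments
  // are copy-assigned, rvalue arguments are move-assigned.
  template <typename T, typename U>
  void setAt(std::vector<T>& v, std::size_t idx, U&& value) {
    v.at(idx) = std::forward<U>(value);
  }

  int main() {
    std::vector<std::string> v{"a", "b"};
    std::string s = "keep";
    setAt(v, 0, s);                   // s still holds "keep" afterwards
    setAt(v, 1, std::string("tmp"));  // rvalue: moved, not copied
  }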
diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp
index f2679e32bf76..e61676b83eca 100644
--- a/torch/csrc/jit/runtime/vararg_functions.cpp
+++ b/torch/csrc/jit/runtime/vararg_functions.cpp
@@ -204,7 +204,10 @@ void namedTupleConstruct(
       c10::ivalue::Tuple::createNamed(std::move(elems), std::move(type)));
 }
 
-void listConstruct(Stack& stack, at::ListTypePtr type, size_t num_inputs) {
+void listConstruct(
+    Stack& stack,
+    const at::ListTypePtr& type,
+    size_t num_inputs) {
   c10::List<IValue> vals(type->getElementType());
   vals.reserve(num_inputs);
   for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) {
@@ -214,7 +217,10 @@ void listConstruct(Stack& stack, at::ListTypePtr type, size_t num_inputs) {
   push(stack, std::move(vals));
 }
 
-void dictConstruct(Stack& stack, at::DictTypePtr type, size_t num_inputs) {
+void dictConstruct(
+    Stack& stack,
+    const at::DictTypePtr& type,
+    size_t num_inputs) {
   at::TypePtr key_type = type->getKeyType();
   at::TypePtr value_type = type->getValueType();
   auto vals = c10::impl::GenericDict(key_type, value_type);
@@ -231,7 +237,7 @@ void dictConstruct(Stack& stack, at::DictTypePtr type, size_t num_inputs) {
   push(stack, std::move(vals));
 }
 
-void createObject(Stack& stack, at::ClassTypePtr type) {
+void createObject(Stack& stack, const at::ClassTypePtr& type) {
   auto userObj = c10::ivalue::Object::create(
       c10::StrongTypePtr(type->compilation_unit(), type),
       type->numAttributes());
@@ -267,19 +273,19 @@ void dequantize(Stack& stack) {
     auto elems = tuple->elements();
     std::vector<IValue> output_elems;
     output_elems.reserve(elems.size());
-    for (size_t i = 0; i < elems.size(); ++i) {
-      if (elems[i].isTensor()) {
-        output_elems.emplace_back(at::dequantize(elems[i].toTensor()));
+    for (const auto& elem : elems) {
+      if (elem.isTensor()) {
+        output_elems.emplace_back(at::dequantize(elem.toTensor()));
       } else {
-        output_elems.emplace_back(elems[i]);
+        output_elems.emplace_back(elem);
       }
     }
     push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
   } else if (iv.isTensorList()) {
     auto elems = iv.toTensorList();
     auto output_list = c10::impl::GenericList(elems.elementType());
-    for (size_t i = 0; i < elems.size(); ++i) {
-      output_list.emplace_back(at::dequantize(elems[i]));
+    for (auto&& elem : elems) {
+      output_list.emplace_back(at::dequantize(elem));
     }
     push(stack, std::move(output_list));
   } else {
diff --git a/torch/csrc/jit/runtime/vararg_functions.h b/torch/csrc/jit/runtime/vararg_functions.h
index 301dde436a18..36bef721d626 100644
--- a/torch/csrc/jit/runtime/vararg_functions.h
+++ b/torch/csrc/jit/runtime/vararg_functions.h
@@ -22,11 +22,17 @@ void namedTupleConstruct(
     at::TupleTypePtr type,
     size_t num_inputs);
 
-void listConstruct(Stack& stack, at::ListTypePtr list_type, size_t num_inputs);
+void listConstruct(
+    Stack& stack,
+    const at::ListTypePtr& list_type,
+    size_t num_inputs);
 
-void dictConstruct(Stack& stack, at::DictTypePtr type, size_t num_inputs);
+void dictConstruct(
+    Stack& stack,
+    const at::DictTypePtr& type,
+    size_t num_inputs);
 
-void createObject(Stack& stack, at::ClassTypePtr type);
+void createObject(Stack& stack, const at::ClassTypePtr& type);
 
 void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types);
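The signature changes in vararg_functions and graph_executor all follow clang-tidy's performance-unnecessary-value-param: at::ListTypePtr, at::DictTypePtr, and at::ClassTypePtr are shared_ptr aliases, and passing a shared_ptr by value performs an atomic reference-count increment and decrement on every call even when the callee never stores the pointer. Passing by const reference skips that traffic. A minimal sketch (Node and NodePtr are stand-ins, not the PyTorch types):

  #include <cstdio>
  #include <memory>

  struct Node { int id = 0; };
  using NodePtr = std::shared_ptr<Node>;

  // Take shared_ptr by const reference when the function only reads through
  // it; copy into a member only where ownership must actually be retained.
  void printNode(const NodePtr& node) {
    std::printf("node %d\n", node->id);
  }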