Update on "[jit] Polymorphic IValue::type() for DynamicType."
Before the change:
```
c10::Type t = ivalue.type();
```
After the change:
```
c10::Type t = ivalue.type();
c10::DynamicType d = ivalue.type<c10::DynamicType>(); // new path
```
The new path will be adopted in the PyTorch Lite Interpreter to support lightweight type reflection. Note that type getters are selected at compile time, so there is no runtime performance overhead. The benefits of having a DynamicType will be elaborated in a separate document, but in short, DynamicType provides an isolated type system for controlling binary size bloat, and shrinks the ~20 supported Type symbols down to one, greatly reducing the size taken by template specializations and function name symbols.
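
A minimal sketch of the idea (illustrative only; the class name `IValueSketch`, its members, and the return types are assumptions, not the actual ATen internals):

```
// Minimal sketch, NOT the real ATen implementation: a templated getter lets
// the caller pick the type system at compile time, so there is no runtime
// branch between the static and dynamic paths.
#include <memory>

struct Type {};                // stand-in for c10::Type
struct DynamicType : Type {};  // stand-in for c10::DynamicType

class IValueSketch {
 public:
  // Defaults to the existing static type system.
  template <typename T = Type>
  std::shared_ptr<T> type() const;
};

// Selected when the caller writes value.type() (or value.type<Type>()).
template <>
inline std::shared_ptr<Type> IValueSketch::type<Type>() const {
  return std::make_shared<Type>();
}

// Selected when the caller writes value.type<DynamicType>(): the new path.
template <>
inline std::shared_ptr<DynamicType> IValueSketch::type<DynamicType>() const {
  return std::make_shared<DynamicType>();
}
```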

To reduce binary size, the Lite Interpreter should only use the `<DynamicType>` variant of these interfaces from ATen.
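
For illustration, a hypothetical call site on the Lite Interpreter side might look like this (the helper function and its use of the returned handle are assumptions; only the `type<c10::DynamicType>()` call is the new API surface):

```
#include <ATen/core/ivalue.h>

// Hypothetical helper (scaffolding only): the one piece taken from this
// change is the templated getter, which on mobile avoids instantiating the
// ~20 statically specialized Type accessors.
void checkRuntimeType(const c10::IValue& value) {
  auto dyn = value.type<c10::DynamicType>();  // new DynamicType path
  // ... perform lightweight type reflection with `dyn` ...
  (void)dyn;
}
```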

Differential Revision: [D33102276](https://our.internmc.facebook.com/intern/diff/D33102276/)

[ghstack-poisoned]
zhxchen17 committed Jan 7, 2022
2 parents 6df124a + b3ad04d commit 00f1ad8
Showing 99 changed files with 1,908 additions and 988 deletions.
4 changes: 2 additions & 2 deletions .circleci/scripts/binary_windows_build.sh
@@ -10,8 +10,8 @@ export SCCACHE_BUCKET=ossci-compiler-cache-windows
export NIGHTLIES_PYTORCH_ROOT="$PYTORCH_ROOT"
export VC_YEAR=2019

if [[ "${DESIRED_CUDA}" == "cu111" || "${DESIRED_CUDA}" == "cu113" ]]; then
export BUILD_SPLIT_CUDA="ON"
if [[ "${DESIRED_CUDA}" == *"cu11"* ]]; then
export BUILD_SPLIT_CUDA=ON
fi

echo "Free Space for CUDA DEBUG BUILD"
2 changes: 2 additions & 0 deletions aten/src/ATen/core/library.cpp
@@ -48,6 +48,8 @@ CppFunction::CppFunction(c10::KernelFunction func, c10::optional<c10::impl::CppS
, debug_()
{}

CppFunction::~CppFunction() = default;

#define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")"

Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
22 changes: 9 additions & 13 deletions aten/src/ATen/native/Convolution.cpp
@@ -1302,7 +1302,6 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
const Tensor& gO_r, const Tensor& weight_r, const Tensor& input,
IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, IntArrayRef output_padding_, int64_t groups_,
bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32,
std::array<bool, 3> output_mask) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt);
@@ -1331,10 +1330,6 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
} else {
params.groups = groups_;
}
params.benchmark = benchmark;
params.deterministic = deterministic;
params.cudnn_enabled = cudnn_enabled;
params.allow_tf32 = allow_tf32;

// Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
Tensor ggO;
@@ -1343,14 +1338,14 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
if (weight.is_cuda()) {
weight = weight.contiguous();
}
ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled, params.allow_tf32);
ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
}

if (ggW.defined()) {
if (ggW.is_cuda()) {
ggW = ggW.contiguous();
}
auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled, params.allow_tf32);
auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
if (ggO.defined()) {
ggO = ggO + ggW_term;
} else {
@@ -1405,9 +1400,9 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
// Compute conv
if (params.transposed) {
gw_conv_params.transposed = false;
gWt = at::_convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
gWt = at::convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
} else {
gWt = at::_convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
gWt = at::convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
}
} else {
std::vector<Tensor> gWt_list(groups);
@@ -1421,9 +1416,9 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
// Compute conv
if (params.transposed) {
gw_conv_params.transposed = false;
gWt_list[g] = at::_convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
gWt_list[g] = at::convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
} else {
gWt_list[g] = at::_convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
gWt_list[g] = at::convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
}
}

@@ -1459,7 +1454,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
if (gO.is_cuda()) {
gO = gO.contiguous();
}
gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32);
gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);

// narrow gI to only relevant portion
// we do it this way because negative output_padding is not supported
@@ -1493,7 +1488,8 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::option
if (gO.is_cuda()) {
gO = gO.contiguous();
}
gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32);

gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
}
}
}
4 changes: 3 additions & 1 deletion aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -1203,14 +1203,16 @@ void scatter_impl(
ReduceStub& reduce_stub,
FillStub& fill_stub,
const c10::optional<c10::string_view> reduce = nullopt) {
if (index.numel() == 0) return;

dim = at::maybe_wrap_dim(dim, self.dim());
auto mut_out = const_cast<Tensor&>(out);

if (!self.is_same(mut_out)) {
mut_out.copy_(self);
}

if (index.numel() == 0) return;

if (reduce.has_value()) {
auto op = meta::get_operator_enum(reduce.value());
reduce_stub(self.device().type(), mut_out, dim, index, src, op);
3 changes: 1 addition & 2 deletions aten/src/ATen/native/cpu/SerialStackImpl.h
@@ -35,8 +35,7 @@ void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t di
int64_t ninputs = tensors.size();
std::vector<InputMeta> inputs;
inputs.reserve(ninputs);
for (const auto i : c10::irange(tensors.size())) {
auto& tensor = tensors[i];
for (const auto& tensor : tensors) {
inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
}

3 changes: 2 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -1326,7 +1326,7 @@

- func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[] dilation, int groups) -> Tensor

- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)

- func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor

@@ -3713,6 +3713,7 @@
python_module: nn
dispatch:
MkldnnCPU: mkldnn_gelu
QuantizedCPU: gelu_quantized_cpu

- func: gelu_backward.grad_input(Tensor grad, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
structured: True
50 changes: 50 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -615,6 +615,54 @@ static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
});
}

void qgelu_kernel(const Tensor& qx, Tensor& qy) {
int64_t zero_point = qx.q_zero_point();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float scale = qx.q_scale();
auto scale_vec = Vectorized<float>(scale);
auto zero_point_vec = Vectorized<float>((float)zero_point);
auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
int64_t output_zero_point = zero_point;
float output_scale = scale;
float inv_output_scale = 1.0 / output_scale;
const auto kAlphaVec = Vectorized<float>(M_SQRT1_2);
const auto kOneVec = Vectorized<float>(1);
const auto kPointFiveVec = Vectorized<float>(0.5);

AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
output_scale,
output_zero_point,
c10::nullopt);
auto iter = TensorIterator::unary_op(qy, qx);

using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(
iter,
[&](scalar_t value_qx) -> scalar_t {
const auto value_dx =
at::native::dequantize_val(scale, zero_point, value_qx);
const auto value_dy =
value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
return at::native::quantize_val<scalar_t>(
output_scale, output_zero_point, value_dy);
},
[&](Vec value_qx) -> Vec {
auto value_dx = value_qx.dequantize(
scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
for (auto & value : value_dx) {
value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
}
return Vec::quantize(
value_dx, output_scale, output_zero_point, inv_output_scale);
});
});
}


void qsigmoid_kernel(
const Tensor& qx, Tensor& qy, double output_scale, int64_t output_zero_point ) {
int64_t zero_point = qx.q_zero_point();
@@ -3467,6 +3515,7 @@ REGISTER_NO_AVX512_DISPATCH(qmul_relu_stub);
REGISTER_NO_AVX512_DISPATCH(qmul_stub);
REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub);
REGISTER_NO_AVX512_DISPATCH(qrelu_stub);
REGISTER_NO_AVX512_DISPATCH(qgelu_stub);
REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub);
REGISTER_NO_AVX512_DISPATCH(qtanh_stub);
REGISTER_NO_AVX512_DISPATCH(qthreshold_stub);
@@ -3518,6 +3567,7 @@ REGISTER_DISPATCH(qmul_relu_stub, &qmul_kernel<true>);
REGISTER_DISPATCH(qmul_stub, &qmul_kernel<false>);
REGISTER_DISPATCH(qrelu_leaky_stub, &leaky_qrelu_out_kernel);
REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
REGISTER_DISPATCH(qgelu_stub, &qgelu_kernel);
REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel);
REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel);
23 changes: 23 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qgelu.cpp
@@ -0,0 +1,23 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <torch/library.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#include <algorithm>

namespace at {
namespace native {

DEFINE_DISPATCH(qgelu_stub);

Tensor gelu_quantized_cpu(const Tensor& qx) {
Tensor qy;
qgelu_stub(qx.device().type(), qx, qy);
return qy;
}
}} // namespace at::native
2 changes: 2 additions & 0 deletions aten/src/ATen/native/quantized/cpu/quantized_ops.h
@@ -8,6 +8,7 @@ namespace native {
using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
const Scalar& /*negval_*/);
using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qclamp_fn = void (*)(
@@ -180,6 +181,7 @@ DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub);
DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub);
DECLARE_DISPATCH(qgelu_fn, qgelu_stub);
DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub);
DECLARE_DISPATCH(qtanh_fn, qtanh_stub);
DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub);
34 changes: 0 additions & 34 deletions aten/src/ATen/templates/RegisterSchema.cpp
@@ -20,40 +20,6 @@ namespace at {
TORCH_LIBRARY(aten, m) {
${schema_registrations};

// String Ops
// Implementations located in torch/csrc/jit/runtime/register_prim_ops.cpp
m.def(TORCH_SELECTIVE_SCHEMA("aten::splitlines(str self, bool keepends=False) -> str[]"));
m.def(TORCH_SELECTIVE_SCHEMA(
"aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::isupper(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::islower(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::capitalize(str self) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::title(str self) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::center(str self, int width, str fillchar=' ') -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::count(str self, str substr, int start=0, int end=-1) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::endswith(str self, str substr, int start=0, int end=-1) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::startswith(str self, str substr, int start=0, int end=-1) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::expandtabs(str self, int tabsize=8) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::find(str self, str substr, int start=0, int end=-1) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::rfind(str self, str substr, int start=0, int end=-1) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::index.str(str self, str substr, int start=0, int end=-1) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::rindex(str self, str substr, int start=0, int end=-1) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::isidentifier(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::istitle(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::isprintable(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::ljust(str self, int width, str fillchar=' ') -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::rjust(str self, int width, str fillchar=' ') -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::zfill(str self, int width) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::lstrip(str self, str chars=' \\n\\t\\f\\v') -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::rstrip(str self, str chars=' \\n\\t\\f\\v') -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::strip(str self, str chars=' \\n\\t\\f\\v') -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::replace(str self, str old, str new, int max=-1) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::partition(str self, str separator) -> (str, str, str)"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::rpartition(str self, str separator) -> (str, str, str)"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::split.str(str self, str? separator=None, int max=-1) -> str[]"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::rsplit(str self, str separator=' ', int max=-1) -> str[]"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::join(str self, str[] values) -> str"));

// Distributed Ops
// Implementations located in torch/csrc/jit/runtime/register_distributed_ops.cpp
m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)");
11 changes: 6 additions & 5 deletions caffe2/operators/gather_ranges_to_dense_op.h
@@ -104,22 +104,23 @@ class GatherRangesToDenseOp final : public Operator<Context> {
int rangesDataOffset = 0;
auto itemsize = data.dtype().itemsize();

auto batchSize = ranges.size(0);
const auto batchSize = ranges.size(0);
vector<int64_t> outputDims{batchSize, 0};
vector<char*> outputRawData;
outputRawData.reserve(OutputSize());
for (const auto i : c10::irange(OutputSize())) {
auto* output = Output(i);
auto *const output = Output(i);
outputDims[1] = lengths_[i];
output->Resize(outputDims);
char* ptr = static_cast<char*>(output->raw_mutable_data(data.dtype()));
char *const ptr = static_cast<char*>(output->raw_mutable_data(data.dtype()));
memset(ptr, 0, output->nbytes());
outputRawData.push_back(ptr);
}

for (const auto i : c10::irange(batchSize)) {
for (const auto j : c10::irange(OutputSize())) {
auto rangeStart = rangesData[rangesDataOffset++];
auto rangeLength = rangesData[rangesDataOffset++];
const auto rangeStart = rangesData[rangesDataOffset++];
const auto rangeLength = rangesData[rangesDataOffset++];

if (rangeLength == 0) {
// empty range, will be filled with zeros
5 changes: 3 additions & 2 deletions caffe2/operators/reshape_op.h
@@ -6,6 +6,7 @@
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "c10/util/irange.h"

namespace caffe2 {

@@ -97,7 +98,7 @@ class ReshapeOp : public Operator<Context> {
}

int unknown_idx = -1;
for (int i = 0; i < actual_new_shape.size(); ++i) {
for (const auto i : c10::irange(actual_new_shape.size())) {
const auto dim = actual_new_shape[i];
if (dim == -1) {
CAFFE_ENFORCE(
@@ -153,7 +154,7 @@ class ReshapeOp : public Operator<Context> {
old_shape->Resize(input.sizes().size());
T* old_shape_data = old_shape->template mutable_data<T>();
std::vector<T> old_shape_vector(input.sizes().begin(), input.sizes().end());
for (int i = 0; i < old_shape_vector.size(); ++i) {
for (const auto i : c10::irange(old_shape_vector.size())) {
old_shape_data[i] = old_shape_vector[i];
}

3 changes: 2 additions & 1 deletion caffe2/operators/reverse_packed_segs_op.h
@@ -3,6 +3,7 @@

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "c10/util/irange.h"

namespace caffe2 {

@@ -62,7 +63,7 @@ class ReversePackedSegsOp final : public Operator<Context> {
context_.FinishDeviceComputation();

T* rev_data_ptr = output->template mutable_data<T>();
for (int64_t i = 0; i < batch_size; i++) {
for (const auto i : c10::irange(batch_size)) {
const auto& seg_length = lengths_host[i];
CAFFE_ENFORCE_LE(seg_length, max_length);
int64_t j = 0;
3 changes: 1 addition & 2 deletions caffe2/operators/rnn/recurrent_network_blob_fetcher_op.h
@@ -31,8 +31,7 @@ class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {

std::vector<std::string> blob_names_vector = {};

// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
for (int64_t i = 0; i < stepWorkspaces.size(); i++) {
for (const auto i : c10::irange(stepWorkspaces.size())) {
Workspace* currentStepWorkspace = stepWorkspaces[i].get();
std::vector<std::string> blob_names = currentStepWorkspace->LocalBlobs();
