
Commit 0bc1814 (2 parents: 2fca582 + 9056fd9)

Update

[ghstack-poisoned]

86 files changed, +1846 -468 lines


.ci/docker/common/install_acl.sh

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 set -euo pipefail

-readonly version=v23.08
+readonly version=v24.04
 readonly src_host=https://review.mlplatform.org/ml
 readonly src_repo=ComputeLibrary

.github/workflows/lint.yml

Lines changed: 0 additions & 2 deletions

@@ -43,8 +43,6 @@ jobs:
       submodules: true
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |
-        pip install onnx==1.16.0
-        pip install numpy==1.26.4
         export ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT"
         .github/scripts/lintrunner.sh


BUILD.bazel

Lines changed: 2 additions & 0 deletions

@@ -663,6 +663,7 @@ cu_library(
     name = "torch_cuda",
     srcs = [
         "torch/csrc/distributed/c10d/intra_node_comm.cu",
+        "torch/csrc/distributed/c10d/Utils.cu",
         "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     ],
     copts = torch_cuda_half_options,
@@ -830,6 +831,7 @@ cc_library(
         "torch/csrc/cuda/python_nccl.cpp",
         "torch/csrc/cuda/nccl.cpp",
         "torch/csrc/distributed/c10d/intra_node_comm.cu",
+        "torch/csrc/distributed/c10d/Utils.cu",
         "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     ],
 )) + torch_sources,

aten/src/ATen/native/FusedAdagrad.cpp

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/core/Tensor.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/FusedAdagrad.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/_fused_adagrad.h>
+#include <ATen/ops/_fused_adagrad_native.h>
+#endif
+namespace at {
+
+namespace native {
+
+void _fused_adagrad_kernel_cpu_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList state_sums,
+    at::TensorList state_steps,
+    const double lr,
+    const double lr_decay,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf) {
+  const float* grad_scale_ptr =
+      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
+  const float* found_inf_ptr =
+      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
+  if (found_inf_ptr && *found_inf_ptr == 1.0) {
+    return;
+  }
+  size_t n_tensors = params.size();
+  TORCH_CHECK(grads.size() == n_tensors);
+  TORCH_CHECK(state_sums.size() == n_tensors);
+  TORCH_CHECK(state_steps.size() == n_tensors);
+  for (size_t i = 0; i < n_tensors; i++){
+    fused_adagrad_stub(
+        kCPU,
+        params[i],
+        grads[i],
+        state_sums[i],
+        state_steps[i],
+        lr,
+        lr_decay,
+        weight_decay,
+        eps,
+        maximize,
+        grad_scale_ptr);
+  }
+}
+
+DEFINE_DISPATCH(fused_adagrad_stub);
+
+}
+}
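
The new `_fused_adagrad_kernel_cpu_` entry point takes lists of parameters, gradients, Adagrad state sums, and step counters together with the scalar hyper-parameters, bails out early when `found_inf` signals an overflow from gradient scaling, and otherwise dispatches `fused_adagrad_stub` once per tensor. As a rough illustration, here is a minimal C++ sketch that calls the new CPU entry point directly on tiny tensors; it assumes a libtorch build where the generated `ATen/ops/_fused_adagrad_native.h` header is available, and in normal use this path is reached through the `_fused_adagrad_` operator rather than called by hand.

#include <ATen/ATen.h>
#include <ATen/ops/_fused_adagrad_native.h>
#include <iostream>

int main() {
  // One parameter with its gradient, accumulated squared-gradient state,
  // and step counter (the CPU kernel reads the step via item<float>()).
  auto param = at::ones({8}, at::kFloat);
  auto grad = at::full({8}, 0.1, at::kFloat);
  auto state_sum = at::zeros({8}, at::kFloat);
  auto state_step = at::ones({1}, at::kFloat);

  at::native::_fused_adagrad_kernel_cpu_(
      {param}, {grad}, {state_sum}, {state_step},
      /*lr=*/0.01, /*lr_decay=*/0.0, /*weight_decay=*/0.0, /*eps=*/1e-10,
      /*maximize=*/false,
      /*grad_scale=*/c10::nullopt, /*found_inf=*/c10::nullopt);

  // state_sum now holds grad*grad and param has moved by roughly -lr per element.
  std::cout << param << std::endl;
  return 0;
}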

aten/src/ATen/native/FusedAdagrad.h

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+#include <ATen/core/Tensor.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+
+namespace native {
+
+using fused_adagrad_fn = void (*)(
+    const at::Tensor& param,
+    const at::Tensor& grad,
+    const at::Tensor& state_sum,
+    const at::Tensor& state_step,
+    const double lr,
+    const double lr_decay,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const float* grad_scale_ptr);
+
+DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);
+
+}
+}

aten/src/ATen/native/ReduceAllOps.cpp

Lines changed: 8 additions & 0 deletions

@@ -8,6 +8,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_aminmax_native.h>
 #include <ATen/ops/aminmax.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/max.h>
@@ -65,4 +66,11 @@ Tensor& max_unary_out(const Tensor &self, Tensor& out) {
   return out;
 }

+// DEPRECATED: Use at::aminmax instead
+std::tuple<Tensor, Tensor> _aminmax_all(const Tensor &self) {
+  TORCH_WARN_ONCE("_aminmax is deprecated as of PyTorch 1.11 and will be removed in a future release. Use aminmax instead."
+      " This warning will only appear once per process.");
+  return at::aminmax(self);
+}
+
 } // namespace at::native
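
The deprecated `_aminmax_all` overload now just forwards to `at::aminmax` and emits a one-time warning, so callers can switch to `aminmax` directly. A minimal sketch of the replacement call from C++ (assuming a libtorch build; `torch.aminmax` is the Python-side equivalent):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto t = at::randn({2, 3});
  // Replacement for the deprecated _aminmax(Tensor): a single pass that
  // returns both the global minimum and maximum of the tensor.
  auto [min_val, max_val] = at::aminmax(t);
  std::cout << min_val << "\n" << max_val << std::endl;
  return 0;
}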

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 8 additions & 0 deletions

@@ -20,6 +20,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_aminmax_native.h>
 #include <ATen/ops/_assert_async_native.h>
 #include <ATen/ops/_functional_assert_async_native.h>
 #include <ATen/ops/_print_native.h>
@@ -681,6 +682,13 @@ std::tuple<Tensor, Tensor> qmin(const Tensor& self, int64_t dim, bool keepdim) {
       at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
 }

+// DEPRECATED: Use at::aminmax instead
+std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdim) {
+  TORCH_WARN_ONCE("_aminmax is deprecated as of PyTorch 1.11 and will be removed in a future release. Use aminmax instead."
+      " This warning will only appear once per process.");
+  return at::aminmax(self, dim, keepdim);
+}
+
 TORCH_IMPL_FUNC(clamp_out)
 (
 const Tensor& /*self*/,
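
The dim-wise `_aminmax(self, dim, keepdim)` overload is deprecated the same way and forwards to `at::aminmax(self, dim, keepdim)`. A small sketch of the dim-wise replacement (again assuming a libtorch build):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto t = at::randn({4, 5});
  // Per-row min and max, keeping the reduced dimension, in place of the
  // deprecated _aminmax(Tensor, dim, keepdim).
  auto [min_vals, max_vals] = at::aminmax(t, /*dim=*/1, /*keepdim=*/true);
  std::cout << min_vals.sizes() << " " << max_vals.sizes() << std::endl;  // [4, 1] [4, 1]
  return 0;
}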

aten/src/ATen/native/TypeProperties.cpp

Lines changed: 2 additions & 2 deletions

@@ -191,8 +191,8 @@ ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) {
   return result_type(state);
 }

-bool can_cast(const at::ScalarType from_, const at::ScalarType to) {
-  return at::canCast(from_, to);
+bool can_cast(const at::ScalarType from, const at::ScalarType to) {
+  return at::canCast(from, to);
 }

 ScalarType promote_types(ScalarType type1, ScalarType type2) {
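
This hunk only renames the `from_` parameter to `from`; the promotion check performed by `at::canCast` is unchanged. For reference, a small sketch of what that check reports (assuming a libtorch build; `torch.can_cast` mirrors it in Python):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Widening casts are allowed under the type-promotion rules,
  // narrowing ones are not.
  std::cout << at::canCast(at::kInt, at::kFloat) << std::endl;  // 1 (allowed)
  std::cout << at::canCast(at::kFloat, at::kInt) << std::endl;  // 0 (not allowed)
  return 0;
}
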
Lines changed: 225 additions & 0 deletions

@@ -0,0 +1,225 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/core/Tensor.h>
+#include <ATen/Parallel.h>
+#include <ATen/OpMathType.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/FusedAdagrad.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cpu/vec/vec.h>
+#include <ATen/cpu/vec/functional.h>
+namespace at::native {
+
+namespace{
+
+template <typename scalar_t, typename opmath_t>
+typename std::enable_if<
+    std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
+    void>::
+    type inline adagrad_math(
+  scalar_t* param_ptr,
+  scalar_t* grad_ptr,
+  scalar_t* state_sum_ptr,
+  const double clr,
+  const double eps,
+  const double weight_decay,
+  const bool maximize,
+  const float* grad_scale_ptr,
+  int64_t size
+){
+  using lpVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<opmath_t>;
+  lpVec grad_vec_to_store;
+  fVec param_vec1, param_vec2;
+  fVec grad_vec1, grad_vec2;
+  fVec state_sum_vec1, state_sum_vec2;
+  int64_t d = 0;
+  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
+    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
+    std::tie(param_vec1, param_vec2) = vec::convert_to_float<scalar_t>(param_lpvec);
+    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
+    std::tie(grad_vec1, grad_vec2) = vec::convert_to_float<scalar_t>(grad_lpvec);
+    if (grad_scale_ptr) {
+      grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
+      grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
+      grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
+      grad_vec_to_store.store(grad_ptr + d);
+    }
+    if (maximize){
+      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
+      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
+    }
+    if (weight_decay != 0.0){
+      grad_vec1 += param_vec1 * fVec(scalar_t(weight_decay));
+      grad_vec2 += param_vec2 * fVec(scalar_t(weight_decay));
+    }
+    std::tie(state_sum_vec1, state_sum_vec2) = vec::convert_to_float<scalar_t>(lpVec::loadu(state_sum_ptr + d));
+    state_sum_vec1 += grad_vec1 * grad_vec1;
+    state_sum_vec2 += grad_vec2 * grad_vec2;
+    vec::convert_from_float<scalar_t>(state_sum_vec1, state_sum_vec2).store(state_sum_ptr + d);
+
+    fVec std_vec1 = state_sum_vec1.sqrt() + fVec(scalar_t(eps));
+    fVec std_vec2 = state_sum_vec2.sqrt() + fVec(scalar_t(eps));
+    param_vec1 = param_vec1 - fVec(scalar_t(clr)) * grad_vec1 / std_vec1;
+    param_vec2 = param_vec2 - fVec(scalar_t(clr)) * grad_vec2 / std_vec2;
+    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
+  }
+  scalar_t grad_val_to_store;
+  for (; d < size; d++) {
+    opmath_t grad_val = grad_ptr[d];
+    opmath_t param_val = param_ptr[d];
+    if (grad_scale_ptr) {
+      grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
+      grad_val_to_store = grad_val;
+      grad_ptr[d] = grad_val_to_store;
+    }
+    if (maximize) grad_val = -grad_val;
+    if (weight_decay != 0.0){
+      grad_val += param_val * opmath_t(weight_decay);
+    }
+    opmath_t state_sum_val = state_sum_ptr[d];
+    state_sum_val += grad_val * grad_val;
+    state_sum_ptr[d] = state_sum_val;
+    opmath_t std_val = std::sqrt(state_sum_val) + opmath_t(eps);
+    param_val -= opmath_t(clr) * grad_val / std_val;
+    param_ptr[d] = param_val;
+  }
+}
+
+
+template <typename scalar_t, typename opmath_t>
+typename std::enable_if<
+    std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
+    void>::
+    type inline adagrad_math(
+  scalar_t* param_ptr,
+  scalar_t* grad_ptr,
+  scalar_t* state_sum_ptr,
+  const double clr,
+  const double eps,
+  const double weight_decay,
+  const bool maximize,
+  const float* grad_scale_ptr,
+  int64_t size
+){
+  using Vec = at::vec::Vectorized<scalar_t>;
+  Vec grad_vec_to_store;
+  int64_t d = 0;
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec param_vec = Vec::loadu(param_ptr + d);
+    Vec grad_vec = Vec::loadu(grad_ptr + d);
+    if (grad_scale_ptr) {
+      grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
+      grad_vec_to_store = grad_vec;
+      grad_vec_to_store.store(grad_ptr + d);
+    }
+    if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
+    if (weight_decay != 0.0){
+      grad_vec += param_vec * Vec(scalar_t(weight_decay));
+    }
+
+    Vec sum_vec = Vec::loadu(state_sum_ptr + d) + grad_vec * grad_vec;
+    sum_vec.store(state_sum_ptr + d);
+
+    Vec std_vec = sum_vec.sqrt() + Vec(scalar_t(eps));
+    param_vec = param_vec - Vec(scalar_t(clr)) * grad_vec / std_vec;
+    param_vec.store(param_ptr + d);
+  }
+  scalar_t grad_val_to_store;
+  for (; d < size; d++) {
+    scalar_t grad_val = grad_ptr[d];
+    if (grad_scale_ptr) {
+      grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
+      grad_val_to_store = grad_val;
+      grad_ptr[d] = grad_val_to_store;
+    }
+    if (maximize) grad_val = -grad_val;
+    if (weight_decay != 0.0){
+      grad_val += param_ptr[d] * scalar_t(weight_decay);
+    }
+    state_sum_ptr[d] += grad_val * grad_val;
+
+    scalar_t std_val = std::sqrt(state_sum_ptr[d]) + scalar_t(eps);
+    param_ptr[d] -= scalar_t(clr) * grad_val / std_val;
+  }
+}
+
+template <typename scalar_t>
+void adagrad_fused_step_impl(
+    const at::Tensor& param,
+    const at::Tensor& grad,
+    const at::Tensor& state_sum,
+    const at::Tensor& state_step,
+    const double lr,
+    const double lr_decay,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const float* grad_scale_ptr) {
+  using opmath_t = at::opmath_type<scalar_t>;
+  scalar_t* param_data = param.data_ptr<scalar_t>();
+  scalar_t* grad_data = grad.data_ptr<scalar_t>();
+  scalar_t* state_sum_data = state_sum.data_ptr<scalar_t>();
+  double step = state_step.item<float>();
+  double clr = lr / (1.0 + (step - 1.0) * lr_decay);
+
+  constexpr size_t cache_line_size = 64;
+  constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
+  size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);
+
+  auto adagrad_fn = [&](int64_t begin, int64_t end) {
+    // local pointers
+    begin *= cache_line_aligned_task_unit;
+    end = std::min(end * cache_line_aligned_task_unit, param.numel());
+    scalar_t* param_ptr = param_data + begin;
+    scalar_t* grad_ptr = grad_data + begin;
+    scalar_t* state_sum_ptr = state_sum_data + begin;
+
+    const int64_t size = end - begin;
+    adagrad_math<scalar_t, opmath_t>(
+      param_ptr,
+      grad_ptr,
+      state_sum_ptr,
+      clr,
+      eps,
+      weight_decay,
+      maximize,
+      grad_scale_ptr,
+      size
+    );
+  };
+  at::parallel_for(
+      0, num_units, 0, adagrad_fn);
+}
+
+void fused_adagrad_kernel(
+    const at::Tensor& param,
+    const at::Tensor& grad,
+    const at::Tensor& state_sum,
+    const at::Tensor& state_step,
+    const double lr,
+    const double lr_decay,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const float* grad_scale_ptr
+) {
+  Tensor grad_contiguous = grad.contiguous();
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adagrad_kernel", [&] {
+    adagrad_fused_step_impl<scalar_t>(
+      param,
+      grad,
+      state_sum,
+      state_step,
+      lr,
+      lr_decay,
+      weight_decay,
+      eps,
+      maximize,
+      grad_scale_ptr);
+  });
+}
+
+}
+
+REGISTER_DISPATCH(fused_adagrad_stub, &fused_adagrad_kernel);
+} // namespace at::native
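
Both `adagrad_math` specializations above implement the same per-element Adagrad update: optionally unscale the gradient by `grad_scale`, negate it for `maximize`, add `weight_decay * param`, accumulate the squared gradient into `state_sum`, and step the parameter by `clr * grad / (sqrt(state_sum) + eps)`, where `clr = lr / (1 + (step - 1) * lr_decay)` as computed in `adagrad_fused_step_impl`. A plain scalar C++ sketch of that update, stripped of the `Vectorized`/`parallel_for` machinery and the grad-scaling path (names here are local to the sketch):

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar version of the update performed by adagrad_math above.
void adagrad_step(std::vector<double>& param,
                  const std::vector<double>& grad,
                  std::vector<double>& state_sum,
                  double step,  // 1-based step count
                  double lr, double lr_decay,
                  double weight_decay, double eps,
                  bool maximize) {
  const double clr = lr / (1.0 + (step - 1.0) * lr_decay);
  for (size_t d = 0; d < param.size(); ++d) {
    double g = maximize ? -grad[d] : grad[d];
    if (weight_decay != 0.0) {
      g += param[d] * weight_decay;
    }
    state_sum[d] += g * g;                         // accumulate squared gradient
    const double std_val = std::sqrt(state_sum[d]) + eps;
    param[d] -= clr * g / std_val;                 // parameter step
  }
}

int main() {
  std::vector<double> param{1.0, 2.0}, state_sum{0.0, 0.0};
  const std::vector<double> grad{0.1, -0.2};
  adagrad_step(param, grad, state_sum, /*step=*/1, /*lr=*/0.01,
               /*lr_decay=*/0.0, /*weight_decay=*/0.0, /*eps=*/1e-10,
               /*maximize=*/false);
  std::printf("%f %f\n", param[0], param[1]);  // ~0.99 and ~2.01: each moves by ~lr against the gradient sign
  return 0;
}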
