Update base for Update on "Remove hacky double registration of to_here op in reg_distributed_ops"

This was added as part of #38590, but we can use default arguments here. We use fmt::format to bind the default value for the RPC timeout at runtime.

Differential Revision: [D21912719](https://our.internmc.facebook.com/intern/diff/D21912719/)

[ghstack-poisoned]
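
As an illustration of the fmt::format approach described above, here is a minimal sketch; the schema text and the kDefaultRpcTimeoutSeconds constant are placeholders, not the actual identifiers used in reg_distributed_ops.

#include <fmt/format.h>

#include <string>

// Hypothetical default standing in for the runtime-configurable RPC timeout.
constexpr float kDefaultRpcTimeoutSeconds = 60.0;

// Build the operator schema once, with the default timeout value formatted
// into the signature, so the op only needs a single registration.
std::string toHereSchemaWithDefault() {
  return fmt::format(
      "aten::to_here(RRef(t) self, float timeout={}) -> t",
      kDefaultRpcTimeoutSeconds);
}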
rohan-varma committed Jun 15, 2020
2 parents c30f69d + c8c53c8 commit 132a4c9
Showing 37 changed files with 657 additions and 394 deletions.
4 changes: 3 additions & 1 deletion .jenkins/caffe2/test.sh
@@ -140,7 +140,9 @@ pip install --user pytest-sugar
# torchvision tests #
#####################
if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
pip install -q --user git+https://github.com/pytorch/vision.git@ba63fbdb595f41901f074883abc0084145877cf5
# Check out torch/vision at the Jun 11 2020 commit
# This hash must match the one in .jenkins/pytorch/test.sh
pip install -q --user git+https://github.com/pytorch/vision.git@c2e8a00885e68ae1200eb6440f540e181d9125de
pip install -q --user ninja
# JIT C++ extensions require ninja, so put it into PATH.
export PATH="/var/lib/jenkins/.local/bin:$PATH"
4 changes: 3 additions & 1 deletion .jenkins/pytorch/test.sh
@@ -173,7 +173,9 @@ test_aten() {
}

test_torchvision() {
pip_install --user git+https://github.com/pytorch/vision.git@43e94b39bcdda519c093ca11d99dfa2568aa7258
# Check out torch/vision at the Jun 11 2020 commit
# This hash must match the one in .jenkins/caffe2/test.sh
pip_install --user git+https://github.com/pytorch/vision.git@c2e8a00885e68ae1200eb6440f540e181d9125de
}

test_libtorch() {
13 changes: 8 additions & 5 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -291,11 +291,15 @@ void bernoulli_tensor_kernel(Tensor& self, const Tensor& p_, c10::optional<Gener
templates::cpu::bernoulli_kernel(self, p_, generator);
}

#if !AT_MKL_ENABLED()
void bernoulli_scalar_kernel(Tensor& self, double p, c10::optional<Generator> gen) {
void bernoulli_scalar_kernel_default(Tensor& self, double p, c10::optional<Generator> gen) {
CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
templates::cpu::bernoulli_kernel(self, p, generator);
}

#if !AT_MKL_ENABLED()
void bernoulli_scalar_kernel(Tensor& self, double p, c10::optional<Generator> gen) {
bernoulli_scalar_kernel_default(self, p, gen);
}
#else
void bernoulli_scalar_kernel(Tensor &self, double p, c10::optional<Generator> gen) {
if (cpuinfo_initialize() && cpuinfo_vendor_intel == cpuinfo_get_processor(0)->core->vendor) {
@@ -348,9 +352,8 @@ void bernoulli_scalar_kernel(Tensor &self, double p, c10::optional<Generator> ge
}
});
} else {
// Use AT_ASSERTM because this should never be reached, and AT_ASSERTM tells
// users to report this as a bug.
AT_ASSERTM(false, "ATen not compiled with MKL");
// Non-Intel CPU (e.g. AMD): fall back to the default implementation.
bernoulli_scalar_kernel_default(self, p, gen);
}
}
#endif
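
For context, the dispatch introduced above reduces to the following vendor-check sketch (placeholder function names, not the actual ATen kernels):

#include <cpuinfo.h>

// Stand-ins for the MKL-accelerated and default bernoulli kernels defined
// in UnaryOpsKernel.cpp.
void bernoulli_default();
void bernoulli_mkl();

void bernoulli_dispatch() {
  // Take the MKL fast path only on Intel CPUs; on other vendors (e.g. AMD),
  // fall back to the default kernel instead of asserting.
  if (cpuinfo_initialize() &&
      cpuinfo_vendor_intel == cpuinfo_get_processor(0)->core->vendor) {
    bernoulli_mkl();
  } else {
    bernoulli_default();
  }
}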
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/Loops.cuh
@@ -90,6 +90,7 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
using arg2_t = typename traits::template arg<1>::type;
auto a = iter.scalar_value<arg1_t>(1);
iter.remove_operand(1);
const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
return f(a, b);
});
8 changes: 3 additions & 5 deletions aten/src/ATen/native/cuda/UnaryFractionKernels.cu
@@ -7,7 +7,6 @@
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/zmath.cuh>

namespace at { namespace native {

@@ -63,7 +62,7 @@ __host__ __device__ static inline scalar_t reciprocal_wrapper(scalar_t a) {
}

template<typename T>
__host__ __device__ static inline thrust::complex<T> reciprocal_wrapper(thrust::complex<T> v) {
__host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
// Handle extreme cases for numpy compatibility
auto both_inf = [](T real, T imag) {
return (::isinf(real) && ::isinf(imag));
@@ -84,15 +83,14 @@ __host__ __device__ static inline thrust::complex<T> reciprocal_wrapper(thrust::
// If either is Inf, return {0, 0}
return {0, 0};
}
const thrust::complex<T> one = thrust::complex<T>(1.0, 0);
const c10::complex<T> one = c10::complex<T>(1.0, 0);
return one/v;
}

void reciprocal_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "reciprocal_cuda", [&]() {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "reciprocal_cuda", [&] {
using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
gpu_kernel(iter, []GPU_LAMBDA(thrust_t a) -> thrust_t {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return reciprocal_wrapper(a);
});
});
122 changes: 0 additions & 122 deletions aten/src/THC/THCTensorScatterGather.cu
@@ -42,40 +42,6 @@ struct IndexToScatterGatherOffsets {
}
};

// Same as above but using a dynamic number of dimensions.
template <typename IndexType, typename Real>
struct IndexToScatterGatherOffsets<IndexType, Real, -1> {
static __device__ void compute(
IndexType linearId, const int dim,
const TensorInfo<int64_t, IndexType>& index, IndexType* indexOffset,
const TensorInfo<Real, IndexType>& t1, IndexType* t1Offset,
const TensorInfo<Real, IndexType>& t2, IndexType* t2Offset) {
for (int d = index.dims - 1; d >= 0; d--) {
IndexType curDimIndex = linearId % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
*t1Offset += curDimIndex * t1.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
linearId /= index.sizes[d];
}
}

static __device__ void compute(
IndexType linearId, const int dim,
const TensorInfo<int64_t, IndexType>& index, IndexType* indexOffset,
const TensorInfo<Real, IndexType>& t2, IndexType* t2Offset) {
for (int d = index.dims - 1; d >= 0; d--) {
IndexType curDimIndex = linearId % index.sizes[d];
*indexOffset += curDimIndex * index.strides[d];
if (d != dim) {
*t2Offset += curDimIndex * t2.strides[d];
}
linearId /= index.sizes[d];
}
}
};

template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
@@ -106,94 +72,6 @@ __global__ void THCudaTensor_gatherKernel(
}
}

template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void THCudaTensor_scatterKernel(
TensorInfo<Real, IndexType> tensor,
TensorInfo<Real, IndexType> src,
TensorInfo<int64_t, IndexType> index,
const int dim,
const IndexType totalElements) {
for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
linearId < totalElements;
linearId += gridDim.x * blockDim.x) {
IndexType tensorOffset = 0;
IndexType srcOffset = 0;
IndexType indexOffset = 0;

IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
index, &indexOffset,
src, &srcOffset,
tensor, &tensorOffset);

int64_t indexValue = index.data[indexOffset];
CUDA_KERNEL_ASSERT(indexValue >= 0 && indexValue < tensor.sizes[dim]);
tensorOffset += indexValue * tensor.strides[dim];

tensor.data[tensorOffset] = src.data[srcOffset];
}
}

template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void THCudaTensor_scatterAddKernel(
TensorInfo<Real, IndexType> tensor,
TensorInfo<Real, IndexType> src,
TensorInfo<int64_t, IndexType> index,
const int dim,
const IndexType totalElements) {
for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
linearId < totalElements;
linearId += gridDim.x * blockDim.x) {
IndexType tensorOffset = 0;
IndexType srcOffset = 0;
IndexType indexOffset = 0;

IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
index, &indexOffset,
src, &srcOffset,
tensor, &tensorOffset);

int64_t indexValue = index.data[indexOffset];
CUDA_KERNEL_ASSERT(indexValue >= 0 && indexValue < tensor.sizes[dim]);
tensorOffset += indexValue * tensor.strides[dim];

gpuAtomicAdd(&tensor.data[tensorOffset], src.data[srcOffset]);
}
}

template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void THCudaTensor_scatterFillKernel(
TensorInfo<Real, IndexType> tensor,
TensorInfo<int64_t, IndexType> index,
Real value,
const int dim,
const IndexType totalElements) {
for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
linearId < totalElements;
linearId += gridDim.x * blockDim.x) {
IndexType tensorOffset = 0;
IndexType indexOffset = 0;

IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
index, &indexOffset,
tensor, &tensorOffset);

int64_t indexValue = index.data[indexOffset];
CUDA_KERNEL_ASSERT(indexValue >= 0 && indexValue < tensor.sizes[dim]);
tensorOffset += indexValue * tensor.strides[dim];

tensor.data[tensorOffset] = value;
}
}

#include <THC/generic/THCTensorScatterGather.cu>
#include <THC/THCGenerateAllTypes.h>

105 changes: 75 additions & 30 deletions caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
@@ -4,52 +4,97 @@ namespace caffe2 {

template <>
template <typename T>
void LayerNormFakeFp16Op<CPUContext>::ComputeSigmaAndFusedParams(
const int M,
const float eps,
const T* mean,
const T* var,
T* sigma,
T* scale,
T* bias) {
ConstEigenVectorArrayMap<T> var_arr(sigma, M);
EigenVectorArrayMap<T> sigma_arr(sigma, M);
sigma_arr = var_arr + static_cast<T>(eps);
math::Rsqrt<T, CPUContext>(M, sigma, scale, &context_);
math::Mul<T, CPUContext>(M, scale, sigma, sigma, &context_);
EigenVectorArrayMap<T>(bias, M) = -ConstEigenVectorArrayMap<T>(scale, M) *
ConstEigenVectorArrayMap<T>(mean, M);
void LayerNormFakeFp16Op<CPUContext>::fp16_wrap(T* tmp) {
fbgemm::RoundToFloat16(tmp, tmp, 1, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
}

template <>
template <typename T>
void LayerNormFakeFp16Op<CPUContext>::LayerNormForward(
void LayerNormFakeFp16Op<CPUContext>::calcY(
const int M,
const int N,
const T* X,
const T* scale,
const T* bias,
const T* mean,
const T* std,
const T* gamma,
const T* beta,
T* Y) {
ConstEigenArrayMap<T> X_arr(X, N, M);
ConstEigenVectorArrayMap<T> scale_arr(scale, M);
ConstEigenVectorArrayMap<T> bias_arr(bias, M);
ConstEigenVectorArrayMap<T> mean_arr(mean, M);
ConstEigenVectorArrayMap<T> std_arr(std, M);
EigenArrayMap<T> Y_arr(Y, N, M);
T tmp = T(0);

if (gamma != nullptr && beta != nullptr) {
ConstEigenVectorArrayMap<T> gamma_arr(gamma, N);
ConstEigenVectorArrayMap<T> beta_arr(beta, N);
Y_arr = (((X_arr.rowwise() * scale_arr.transpose()).rowwise() +
bias_arr.transpose())
.colwise() *
gamma_arr)
.colwise() +
beta_arr;

for (int i = 0; i < M; ++i) {
T normFactor = T(T(1) / std_arr[i]);
fp16_wrap(&normFactor);
for (int j = 0; j < N; ++j) {
tmp = T(X_arr.col(i)[j] - mean[i]);
fp16_wrap(&tmp);
T normalized = tmp * normFactor;
fp16_wrap(&normalized);
tmp = normalized * gamma_arr[j];
fp16_wrap(&tmp);
tmp = tmp + beta_arr[j];
fp16_wrap(&tmp);
Y_arr.col(i)[j] = tmp;
}
}
} else {
CAFFE_ENFORCE(gamma == nullptr);
CAFFE_ENFORCE(beta == nullptr);
Y_arr = (X_arr.rowwise() * scale_arr.transpose()).rowwise() +
bias_arr.transpose();
for (int i = 0; i < M; ++i) {
T normFactor = T(T(1) / std_arr[i]);
fp16_wrap(&normFactor);
for (int j = 0; j < N; ++j) {
tmp = T(X_arr.col(i)[j] - mean[i]);
fp16_wrap(&tmp);
tmp *= normFactor;
fp16_wrap(&tmp);
Y_arr.col(i)[j] = tmp;
}
}
}
}

template <>
template <typename T>
void LayerNormFakeFp16Op<CPUContext>::calcMeanStd(
const int M,
const int N,
const float eps,
const T* X,
T* mean,
T* std) {
ConstEigenArrayMap<T> X_arr(X, N, M);

T sqr[M];
T var[M];
T inv_N_val = T(1) / N;
T tmp = T(0);

for (int i = 0; i < M; ++i) {
mean[i] = T(0);
sqr[i] = T(0);
var[i] = T(0);
for (int j = 0; j < N; ++j) {
tmp = T(X_arr.col(i)[j] * inv_N_val);
fp16_wrap(&tmp);
mean[i] += tmp;
fp16_wrap(&mean[i]);
tmp *= X_arr.col(i)[j];
fp16_wrap(&tmp);
sqr[i] += tmp;
fp16_wrap(&sqr[i]);
}
tmp = mean[i] * mean[i];
fp16_wrap(&tmp);
var[i] = sqr[i] - tmp;
fp16_wrap(&var[i]);
std[i] = std::sqrt(var[i] + eps);
fp16_wrap(&std[i]);
}
}
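
For reference, calcMeanStd and calcY above compute the standard layer-norm statistics, rounding every intermediate to fp16 via fbgemm::RoundToFloat16:

\mu_i = \frac{1}{N} \sum_{j=1}^{N} x_{ij}, \qquad
\sigma_i^2 = \frac{1}{N} \sum_{j=1}^{N} x_{ij}^2 - \mu_i^2, \qquad
\sigma_i = \sqrt{\sigma_i^2 + \epsilon}, \qquad
y_{ij} = \gamma_j \, \frac{x_{ij} - \mu_i}{\sigma_i} + \beta_j

with \gamma and \beta omitted when no affine parameters are supplied.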

