Skip to content

Commit

Permalink
Fix and optimize handling of vectorized memory accesses (apache#17767)
Browse files Browse the repository at this point in the history
* Vectorized loads for binary elemwise kernel

* More generalization

* Add backwardusenone

* Remove the unused _backward_add op

* Add vectorized backwardusein

* Extending vectorization to more binary ops, binary ops with scalar and
unary ops

* Handling ElementwiseSum

* Get rid of half2 in mshadow

* Remove backward_elemwiseaddex

* Revert "Remove the unused _backward_add op"

This reverts commit f86da86.

* Revert "Remove backward_elemwiseaddex"

This reverts commit 7729114.

* Add back the backward_add since C++ test relies on it

* Test bcast implementations

* First version of vecotrized bcast

* Adding single side vectorized bcast kernel

* Removing debug prints

* Actually run the single side kernel

* Move the default implementation of bcast to the vectorized one

* Limit the new implementation to GPU only

* Enabling vectorization when broadcast does not actually do broadcast

* Cleaning

* Cleaning part 2

* Fix for numpy ops using stuff from broadcast

* Fix

* Fix lint

* Try to debug pinv numpy test

* Fix

* Fix the vectorized broadcast implementation for misaligned input
pointers

* Added tests

* Added docs to cuda_vectorization.cuh

* Another fix for broadcast and fix INT64 compilation

* Optimize for aligned=true

* 1 more addition to test

* Reverting the change to Numpy op test

* Trying mcmodel=medium to fix the failure in CMake static build

* Revert "Trying mcmodel=medium to fix the failure in CMake static build"

This reverts commit 1af684c.

* Limiting the PR to just elementwise ops
  • Loading branch information
ptrendx committed Apr 18, 2020
1 parent 814530d commit c5893e3
Show file tree
Hide file tree
Showing 19 changed files with 1,342 additions and 443 deletions.
48 changes: 0 additions & 48 deletions 3rdparty/mshadow/mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ extern "C" {
}

#include "./half.h"
#include "./half2.h"
#include "./bfloat.h"
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \
MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \
Expand Down Expand Up @@ -391,11 +390,6 @@ struct DataType<half::half_t> {
#endif
};
template<>
struct DataType<half::half2_t> {
static const int kFlag = kFloat16;
static const int kLanes = 2;
};
template<>
struct DataType<bfloat::bf16_t> {
static const int kFlag = kBfloat16;
static const int kLanes = 1;
Expand Down Expand Up @@ -1148,48 +1142,6 @@ struct minimum {
}
#endif

#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half2_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
Expand Down
143 changes: 0 additions & 143 deletions 3rdparty/mshadow/mshadow/half2.h

This file was deleted.

Loading

0 comments on commit c5893e3

Please sign in to comment.