Skip to content

Commit

Permalink
Cherry-pick large tensor support from apache#18752. (apache#18804)
Browse files Browse the repository at this point in the history
Co-authored-by: Joe Evans <joeev@amazon.com>
  • Loading branch information
josephevans and Joe Evans committed Jul 29, 2020
1 parent 126636c commit e9829e7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ List of Contributors
* [Connor Goggins](https://github.com/connorgoggins)
* [Wei Chu](https://github.com/waytrue17)
* [Yang Shi](https://github.com/ys2843)
* [Joe Evans](https://github.com/josephevans)

Label Bot
---------
Expand Down
11 changes: 6 additions & 5 deletions src/operator/tensor/la_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ using namespace mshadow;
// Copies lower/upper triangular part to upper/lower, i.e. to the opposite side.
struct CopyTriangularToOppositeSide {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data, bool to_lower) {
MSHADOW_XINLINE static void Map(index_t i, size_t matrix_size, index_t stride,
DType* data, bool to_lower) {
// Below computation works even when we are dealing with a batch of matrices.
const int row((i % matrix_size) / stride), col(i % stride);
const index_t row((i % matrix_size) / stride), col(i % stride);
if (row > col) {
if (to_lower) {
data[i] = data[i + (col - row) * (stride - 1)];
Expand All @@ -52,9 +53,9 @@ struct CopyTriangularToOppositeSide {
// Zero's lower/upper triangular part of a matrix.
struct ZeroTriangular {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data,
bool zero_lower) {
const int row((i % matrix_size) / stride), col(i % stride);
MSHADOW_XINLINE static void Map(index_t i, size_t matrix_size, index_t stride,
DType* data, bool zero_lower) {
const index_t row((i % matrix_size) / stride), col(i % stride);
if ((!zero_lower && (row < col)) || (zero_lower && (row > col))) data[i] = 0;
}
};
Expand Down

0 comments on commit e9829e7

Please sign in to comment.