Skip to content

Commit

Permalink
Add batchinv function
Browse files Browse the repository at this point in the history
* Implement getri_batched and getrf_batched interfaces for cublas

* [WIP] Implement inverse kernel

* Disable half of inverse

* Fix types at cublas

* Fix cuda helper for inverse

* Implement inverse function

* Add casting around inverse

* Fix inverse

* Fix inv function

* Rename inverse to batch_inv

* Refactor F.batch_inv

* Auto-format

* Merge branch 'master' into feature/20191121-batchinv-seno
  • Loading branch information
AkioHayakawa-sony authored and TakuyaNarihira committed Dec 2, 2019
1 parent 77bc900 commit 2cbb4ea
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 0 deletions.
3 changes: 3 additions & 0 deletions build-tools/code_generator/function_types.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ Sort:
Reshape:
float: [float]
half: [Half]
BatchInv:
  float: [float]
  # half is disabled: the cuBLAS getrf/getri batched wrappers are only
  # specialized for float and double (see src/nbla/cuda/cublas.cpp).
  # half: [Half]
MatrixDiag:
float: [float]
half: [Half]
Expand Down
8 changes: 8 additions & 0 deletions include/nbla/cuda/cublas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,13 @@ void cublas_gemm_strided_batched(cublasHandle_t handle, cublasOperation_t op_x,
float alpha, const T *x, int lda, int stride_a,
const T *y, int ldb, int stride_b, float beta,
T *z, int ldc, int stride_c, int batchCount);

// Wrapper for cuBLAS cublas<t>getrfBatched: in-place batched LU
// factorization of batchSize n-by-n matrices. x is a device array of
// batchSize device pointers to the matrices; pivot (n * batchSize ints)
// and info (batchSize ints, one status per matrix) are device buffers.
template <typename T>
void cublas_getrf_batched(cublasHandle_t handle, int n, T **x, int lda,
                          int *pivot, int *info, int batchSize);

// Wrapper for cuBLAS cublas<t>getriBatched: batched inversion using the LU
// factors and pivots produced by cublas_getrf_batched. Reads factors via x
// and writes the inverses via y (both arrays of batchSize device pointers).
template <typename T>
void cublas_getri_batched(cublasHandle_t handle, int n, const T **x, int lda,
                          int *pivot, T **y, int ldc, int *info, int batchSize);
}
#endif
48 changes: 48 additions & 0 deletions include/nbla/cuda/function/batch_inv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/** BatchInv: batched matrix inverse (CUDA implementation).
 */
// Guard renamed to match the file name (batch_inv.hpp); the old
// __NBLA_CUDA_FUNCTION_INVERSE_HPP__ guard predates the rename and could
// collide with a future inverse.hpp.
#ifndef __NBLA_CUDA_FUNCTION_BATCH_INV_HPP__
#define __NBLA_CUDA_FUNCTION_BATCH_INV_HPP__

#include <nbla/cuda/common.hpp>
#include <nbla/cuda/cuda.hpp>
#include <nbla/function/batch_inv.hpp>
namespace nbla {
/** @copydoc BatchInv
 */

template <typename T> class BatchInvCuda : public BatchInv<T> {

public:
  typedef typename CudaType<T>::type Tc;
  explicit BatchInvCuda(const Context &ctx)
      : BatchInv<T>(ctx), device_(std::stoi(ctx.device_id)) {}
  virtual ~BatchInvCuda() {}
  virtual string name() { return "BatchInvCuda"; }
  virtual vector<string> allowed_array_classes() {
    return SingletonManager::get<Cuda>()->array_classes();
  }
  // Backward depends on the forward output data for all (input, output) pairs.
  virtual bool grad_depends_output_data(int i, int o) const { return true; }

protected:
  int device_; // CUDA device id parsed from the context.
  // Matrix dimension n and number of matrices; cached by setup_impl from the
  // input shape (batch_size, n, n).
  int dim_, batch_size_;
  virtual void setup_impl(const Variables &inputs, const Variables &outputs);
  virtual void forward_impl(const Variables &inputs, const Variables &outputs);
};
}

#endif
21 changes: 21 additions & 0 deletions include/nbla/cuda/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,5 +143,26 @@ void cuda_gemm_strided_batched(int device, T *z, bool transpose_z, const T *x,
_RCC(y), row_y, row_y * col_y, beta, _RC(z), m, m * n, batch_count);
}
#endif

// Runs a batched in-place LU factorization of batchSize n-by-n matrices on
// the given device via the cuBLAS getrfBatched wrapper. pivot and info are
// device buffers receiving the pivot indices and per-matrix status codes.
template <typename T>
void cuda_getrf_batched(int device, int n, T **x, int *pivot, int *info,
                        int batchSize) {
  _TD();
  cublasHandle_t cublas = SingletonManager::get<Cuda>()->cublas_handle(device);
  Tc **mats = reinterpret_cast<Tc **>(x);
  // lda is fixed to n; tuning the leading dimension is left for the future.
  cublas_getrf_batched<Tc>(cublas, n, mats, n, pivot, info, batchSize);
}

// Computes batched matrix inverses from LU factors (produced by
// cuda_getrf_batched) via the cuBLAS getriBatched wrapper, writing the
// inverses through the y pointer array.
template <typename T>
void cuda_getri_batched(int device, int n, const T **x, int *pivot, T **y,
                        int *info, int batchSize) {
  _TD();
  cublasHandle_t cublas = SingletonManager::get<Cuda>()->cublas_handle(device);
  const Tc **lu_factors = reinterpret_cast<const Tc **>(x);
  Tc **inverses = reinterpret_cast<Tc **>(y);
  // lda/ldc are fixed to n; tuning the leading dimensions is future work.
  cublas_getri_batched<Tc>(cublas, n, lu_factors, n, pivot, inverses, n, info,
                           batchSize);
}
}
#endif
37 changes: 37 additions & 0 deletions src/nbla/cuda/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,4 +319,41 @@ void cublas_gemm_strided_batched<half>(
}
}
#endif // CUDA_VERSION >= 8000

// ----------------------------------------------------------------------
// Getrf batched
// ----------------------------------------------------------------------
// Double-precision batched in-place LU factorization; cuBLAS status is
// checked by NBLA_CUBLAS_CHECK.
template <>
void cublas_getrf_batched<double>(cublasHandle_t handle, int n, double **x,
                                  int lda, int *pivot, int *info,
                                  int batchSize) {
  NBLA_CUBLAS_CHECK(
      cublasDgetrfBatched(handle, n, x, lda, pivot, info, batchSize));
}
// Single-precision batched in-place LU factorization.
template <>
void cublas_getrf_batched<float>(cublasHandle_t handle, int n, float **x,
                                 int lda, int *pivot, int *info,
                                 int batchSize) {
  NBLA_CUBLAS_CHECK(
      cublasSgetrfBatched(handle, n, x, lda, pivot, info, batchSize));
}

// ----------------------------------------------------------------------
// Getri batched
// ----------------------------------------------------------------------
// Double-precision batched inversion from LU factors; cuBLAS status is
// checked by NBLA_CUBLAS_CHECK.
template <>
void cublas_getri_batched<double>(cublasHandle_t handle, int n,
                                  const double **x, int lda, int *pivot,
                                  double **y, int ldc, int *info,
                                  int batchSize) {
  NBLA_CUBLAS_CHECK(
      cublasDgetriBatched(handle, n, x, lda, pivot, y, ldc, info, batchSize));
}
// Single-precision batched inversion from LU factors.
template <>
void cublas_getri_batched<float>(cublasHandle_t handle, int n, const float **x,
                                 int lda, int *pivot, float **y, int ldc,
                                 int *info, int batchSize) {
  NBLA_CUBLAS_CHECK(
      cublasSgetriBatched(handle, n, x, lda, pivot, y, ldc, info, batchSize));
}
}
78 changes: 78 additions & 0 deletions src/nbla/cuda/function/generic/batch_inv.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <nbla/array.hpp>
#include <nbla/cuda/array/cuda_array.hpp>
#include <nbla/cuda/common.hpp>
#include <nbla/cuda/function/batch_inv.hpp>
#include <nbla/cuda/math.hpp>
#include <nbla/function/batch_matmul.hpp>
#include <nbla/variable.hpp>

namespace nbla {

// Fills ptr[idx] with the device address of the idx-th n x n matrix in the
// contiguous mini-batch buffer starting at head (one entry per matrix).
template <typename T>
__global__ void kernel_set_batch_pointers(int batchSize, int n, const T **ptr,
                                          const T *head) {
  NBLA_CUDA_KERNEL_LOOP(idx, batchSize) { ptr[idx] = head + idx * n * n; }
}

// Declares a device array of BATCH matrix pointers named dev_list_##NAME and
// fills it with the per-matrix head addresses of the contiguous buffer PTR
// via kernel_set_batch_pointers (matrix dimension taken from member dim_).
// CONST is either `const` or empty and controls the pointee constness.
#define NBLA_GET_BATCH_POINTERS(PTR, NAME, BATCH, CONST)                       \
  CudaCachedArray list_##PTR(sizeof(Tc *) * BATCH, dtypes::BYTE, this->ctx_);  \
  CONST Tc **dev_list_##NAME =                                                 \
      reinterpret_cast<CONST Tc **>(list_##PTR.pointer<void>());               \
  NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_set_batch_pointers, BATCH, dim_,       \
                                 (const T **)dev_list_##NAME, (const T *)PTR)

// ----------------------------------------------------------------------
// With cublas<t>getrfBatched and cublas<t>getriBatched
// ----------------------------------------------------------------------
// Caches the input's batch size and matrix dimension after delegating the
// shape/dtype validation to the base-class setup.
template <typename T>
void BatchInvCuda<T>::setup_impl(const Variables &inputs,
                                 const Variables &outputs) {
  BatchInv<T>::setup_impl(inputs, outputs);
  const auto in_shape = inputs[0]->shape();
  batch_size_ = in_shape[0];
  dim_ = in_shape[1];
}

// Computes y = inv(x) for each n x n matrix in the batch: copy the input to a
// scratch buffer, LU-factorize it in place (getrf), then invert from the
// factors into the output (getri).
template <typename T>
void BatchInvCuda<T>::forward_impl(const Variables &inputs,
                                   const Variables &outputs) {
  cuda_set_device(this->device_);
  const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_);
  Tc *y = outputs[0]->cast_data_and_get_pointer<Tc>(this->ctx_, true);

  // Device scratch: pivot indices (dim_ per matrix), per-matrix status codes,
  // and a working copy of the input (getrf factorizes in place, so the input
  // must not be clobbered).
  CudaCachedArray pivot(dim_ * batch_size_, dtypes::INT, this->ctx_);
  CudaCachedArray info(batch_size_, dtypes::INT, this->ctx_);
  CudaCachedArray lu(inputs[0]->size(), get_dtype<Tc>(), this->ctx_);

  lu.copy_from(inputs[0]->data()->cast(get_dtype<Tc>(), this->ctx_, false));

  // Build device arrays of per-matrix head pointers for the batched cuBLAS
  // calls (names dev_list_lu / dev_list_y are declared by the macro).
  Tc *lu_ptr = lu.pointer<Tc>();
  NBLA_GET_BATCH_POINTERS(lu_ptr, lu, batch_size_, ) // dev_list_lu
  NBLA_GET_BATCH_POINTERS(y, y, batch_size_, );      // dev_list_y

  // LU factorization
  cuda_getrf_batched<Tc>(this->device_, dim_, dev_list_lu, pivot.pointer<int>(),
                         info.pointer<int>(), batch_size_);

  // matrix inversion
  cuda_getri_batched<Tc>(this->device_, dim_, (const Tc **)dev_list_lu,
                         pivot.pointer<int>(), dev_list_y, info.pointer<int>(),
                         batch_size_);
  // NOTE(review): info is never copied back to the host or checked, so a
  // singular input matrix fails silently — consider validating info.
}
}

0 comments on commit 2cbb4ea

Please sign in to comment.