-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement getri_batched and getrf_batched interfaces for cublas * [WIP] Implement inverse kernel * Disable half of inverse * Fix types at cublas * Fix cuda helper for inverse * Implement inverse function * Add casting around inverse * Fix inverse * Fix inv function * Rename inverse to batch_inv * Refactor F.batch_inv * Auto-format * Merge branch 'master' into feature/20191121-batchinv-seno
- Loading branch information
1 parent
77bc900
commit 2cbb4ea
Showing
6 changed files
with
195 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (c) 2017 Sony Corporation. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
/** Inverse | ||
*/ | ||
#ifndef __NBLA_CUDA_FUNCTION_INVERSE_HPP__ | ||
#define __NBLA_CUDA_FUNCTION_INVERSE_HPP__ | ||
|
||
#include <nbla/cuda/common.hpp> | ||
#include <nbla/cuda/cuda.hpp> | ||
#include <nbla/function/batch_inv.hpp> | ||
namespace nbla { | ||
/** @copydoc BatchInv | ||
*/ | ||
|
||
template <typename T> class BatchInvCuda : public BatchInv<T> { | ||
|
||
public: | ||
typedef typename CudaType<T>::type Tc; | ||
explicit BatchInvCuda(const Context &ctx) | ||
: BatchInv<T>(ctx), device_(std::stoi(ctx.device_id)) {} | ||
virtual ~BatchInvCuda() {} | ||
virtual string name() { return "BatchInvCuda"; } | ||
virtual vector<string> allowed_array_classes() { | ||
return SingletonManager::get<Cuda>()->array_classes(); | ||
} | ||
virtual bool grad_depends_output_data(int i, int o) const { return true; } | ||
|
||
protected: | ||
int device_; | ||
int dim_, batch_size_; | ||
virtual void setup_impl(const Variables &inputs, const Variables &outputs); | ||
virtual void forward_impl(const Variables &inputs, const Variables &outputs); | ||
}; | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright (c) 2017 Sony Corporation. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <iostream> | ||
#include <nbla/array.hpp> | ||
#include <nbla/cuda/array/cuda_array.hpp> | ||
#include <nbla/cuda/common.hpp> | ||
#include <nbla/cuda/function/batch_inv.hpp> | ||
#include <nbla/cuda/math.hpp> | ||
#include <nbla/function/batch_matmul.hpp> | ||
#include <nbla/variable.hpp> | ||
|
||
namespace nbla { | ||
|
||
// Sets head pointers of matrices in mini-batch. | ||
template <typename T> | ||
__global__ void kernel_set_batch_pointers(int batchSize, int n, const T **ptr, | ||
const T *head) { | ||
NBLA_CUDA_KERNEL_LOOP(idx, batchSize) { ptr[idx] = head + idx * n * n; } | ||
} | ||
|
||
// A macro that creates an array of pointers of matrices. | ||
#define NBLA_GET_BATCH_POINTERS(PTR, NAME, BATCH, CONST) \ | ||
CudaCachedArray list_##PTR(sizeof(Tc *) * BATCH, dtypes::BYTE, this->ctx_); \ | ||
CONST Tc **dev_list_##NAME = \ | ||
reinterpret_cast<CONST Tc **>(list_##PTR.pointer<void>()); \ | ||
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_set_batch_pointers, BATCH, dim_, \ | ||
(const T **)dev_list_##NAME, (const T *)PTR) | ||
|
||
// ---------------------------------------------------------------------- | ||
// With cublas<t>getrfBatched and cublas<t>getriBatched | ||
// ---------------------------------------------------------------------- | ||
template <typename T> | ||
void BatchInvCuda<T>::setup_impl(const Variables &inputs, | ||
const Variables &outputs) { | ||
BatchInv<T>::setup_impl(inputs, outputs); | ||
batch_size_ = inputs[0]->shape()[0]; | ||
dim_ = inputs[0]->shape()[1]; | ||
} | ||
|
||
template <typename T> | ||
void BatchInvCuda<T>::forward_impl(const Variables &inputs, | ||
const Variables &outputs) { | ||
cuda_set_device(this->device_); | ||
const Tc *x = inputs[0]->get_data_pointer<Tc>(this->ctx_); | ||
Tc *y = outputs[0]->cast_data_and_get_pointer<Tc>(this->ctx_, true); | ||
|
||
CudaCachedArray pivot(dim_ * batch_size_, dtypes::INT, this->ctx_); | ||
CudaCachedArray info(batch_size_, dtypes::INT, this->ctx_); | ||
CudaCachedArray lu(inputs[0]->size(), get_dtype<Tc>(), this->ctx_); | ||
|
||
lu.copy_from(inputs[0]->data()->cast(get_dtype<Tc>(), this->ctx_, false)); | ||
|
||
Tc *lu_ptr = lu.pointer<Tc>(); | ||
NBLA_GET_BATCH_POINTERS(lu_ptr, lu, batch_size_, ) // dev_list_lu | ||
NBLA_GET_BATCH_POINTERS(y, y, batch_size_, ); // dev_list_y | ||
|
||
// LU factorization | ||
cuda_getrf_batched<Tc>(this->device_, dim_, dev_list_lu, pivot.pointer<int>(), | ||
info.pointer<int>(), batch_size_); | ||
|
||
// matrix inversion | ||
cuda_getri_batched<Tc>(this->device_, dim_, (const Tc **)dev_list_lu, | ||
pivot.pointer<int>(), dev_list_y, info.pointer<int>(), | ||
batch_size_); | ||
} | ||
} |