Skip to content

Commit

Permalink
Merge pull request #819 from sony/feature/20210302-fix-async-ones-zeros
Browse files Browse the repository at this point in the history
Fix async copy in NNabla::ones and zeros
  • Loading branch information
TakuyaNarihira committed Mar 8, 2021
2 parents c98aceb + a99fff1 commit 04e028f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
16 changes: 12 additions & 4 deletions src/nbla/singleton_manager.cpp
Expand Up @@ -57,6 +57,14 @@ NNabla::NNabla() {}

NNabla::~NNabla() {}

const void *async_get(const shared_ptr<SyncedArray> &arr, dtypes dtype,
const Context &ctx) {
auto ret = arr->get(dtype, ctx, AsyncFlag::ASYNC)->const_pointer<void>();
arr->get(dtype, ctx); // Workaraound to wait async copy. call get again by the
// same dtype and ctx.
return ret;
}

const void *NNabla::ones(Size_t size, dtypes dtype, const Context &ctx) {
auto tid = std::this_thread::get_id();
shared_ptr<SyncedArray> ones;
Expand All @@ -66,15 +74,15 @@ const void *NNabla::ones(Size_t size, dtypes dtype, const Context &ctx) {
ones = std::make_shared<SyncedArray>(size);
ones->fill(1);
ones_[tid] = ones;
return ones->get(dtype, ctx, AsyncFlag::ASYNC)->const_pointer<void>();
return async_get(ones, dtype, ctx);
}
ones = it->second;
if (size > ones->size()) {
ones = std::make_shared<SyncedArray>(size);
ones->fill(1);
ones_[tid] = ones;
}
return ones->get(dtype, ctx, AsyncFlag::ASYNC)->const_pointer<void>();
return async_get(ones, dtype, ctx);
}

const void *NNabla::zeros(Size_t size, dtypes dtype, const Context &ctx) {
Expand All @@ -86,15 +94,15 @@ const void *NNabla::zeros(Size_t size, dtypes dtype, const Context &ctx) {
zeros = std::make_shared<SyncedArray>(size);
zeros->zero();
ones_[tid] = zeros;
return zeros->get(dtype, ctx, AsyncFlag::ASYNC)->const_pointer<void>();
return async_get(zeros, dtype, ctx);
}
zeros = it->second;
if (size > zeros->size()) {
zeros = std::make_shared<SyncedArray>(size);
zeros->zero();
ones_[tid] = zeros;
}
return zeros->get(dtype, ctx, AsyncFlag::ASYNC)->const_pointer<void>();
return async_get(zeros, dtype, ctx);
}

NBLA_INSTANTIATE_SINGLETON(NBLA_API, NNabla);
Expand Down
4 changes: 1 addition & 3 deletions src/nbla/synced_array.cpp
Expand Up @@ -161,9 +161,7 @@ SyncedArray::ArrayDesc SyncedArray::sync(dtypes dtype, const Context &ctx_orig,
shared_ptr<Array>(ArrayCreator::create(size_, dtype, ctx)), false);
} else {
// Wait for the end of previous async_flags asynchronous memcpy
if (!(async_flags & AsyncFlag::ASYNC)) {
array_[desc.key].first->wait_event(ctx, async_flags);
}
array_[desc.key].first->wait_event(ctx, async_flags);
}

if (write_only) {
Expand Down

0 comments on commit 04e028f

Please sign in to comment.