Commit 8ea543d: "Auto Format"
YukioOobuchi committed Apr 2, 2018
1 parent 2521261 commit 8ea543d
Showing 3 changed files with 18 additions and 21 deletions. The diff consists only of whitespace, line-wrapping, and pointer-spacing changes, consistent with an automatic formatter pass (likely clang-format for the C++ sources); a few removed/added line pairs below render identically because only invisible whitespace changed.
2 changes: 1 addition & 1 deletion python/src/nnabla/utils/converter/nnablart/nnb.py
```diff
@@ -48,7 +48,7 @@ def __init__(self, nnp, batch_size):
 
         self._argument_formats = {}
         for fn, func in self._info._function_info.items():
-            if 'arguments' in func and len(func['arguments']) > 0:
+            if 'arguments' in func and len(func['arguments']) > 0:
                 argfmt = ''
                 for an, arg in func['arguments'].items():
                     if arg['type'] == 'bool':
```
35 changes: 17 additions & 18 deletions src/nbla/function/generic/global_average_pooling.cpp
```diff
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-
 #include <nbla/array.hpp>
 #include <nbla/common.hpp>
 #include <nbla/function/global_average_pooling.hpp>
@@ -22,15 +21,13 @@
 #include <iostream>
 #include <typeinfo>
 
-
-
 namespace nbla {
 
 NBLA_REGISTER_FUNCTION_SOURCE(GlobalAveragePooling);
 
 template <typename T>
 void GlobalAveragePooling<T>::setup_impl(const Variables &inputs,
-                                        const Variables &outputs) {
+                                         const Variables &outputs) {
   // TODO: Remove debug message
   std::cout << "GlobalAveragePooling<" << typeid(T).name()
             << ">::setup_impl called with " << this->ctx_.to_string() << "."
@@ -41,14 +38,15 @@ void GlobalAveragePooling<T>::setup_impl(const Variables &inputs,
   /* TODO: Any preparation comes here.
      Note that, although it is called only when a computation graph is
      constructed in a static computation graph, in a dynamic computation graph,
-     it's called every time. Keep the setup computation light for the performance
+     it's called every time. Keep the setup computation light for the
+     performance
      (caching heavy computation, device synchronization in GPU etc.)
   */
 }
 
 template <typename T>
 void GlobalAveragePooling<T>::forward_impl(const Variables &inputs,
-                                          const Variables &outputs) {
+                                           const Variables &outputs) {
   // TODO: Remove debug message
   std::cout << "GlobalAveragePooling<" << typeid(T).name()
             << ">::forward_impl called with " << this->ctx_.to_string() << "."
```
```diff
@@ -58,29 +56,29 @@ void GlobalAveragePooling<T>::forward_impl(const Variables &inputs,
      The type `Variables` is a typedef of `vector<Variable*>`.
      The `Variable` class owns storages of data (storage for forward propagation)
      and grad (for backprop) respectively.
      You can get a raw pointer of a scalar type of the storage using:
      - `const T* Variable::get_{data|grad}_pointer<T>(ctx)` for read-only access.
      - `T* Variable::cast_{data|grad}_and_get_pointer<T>(ctx)` for r/w access.
-     By this, automatic type conversion would occur if data was held in a different type.
+     By this, automatic type conversion would occur if data was held in a
+     different type.
   */
   // Inputs
-  const T* x = inputs[0]->get_data_pointer<T>(this->ctx_);
+  const T *x = inputs[0]->get_data_pointer<T>(this->ctx_);
 
   // Outputs
-  T* y = outputs[0]->cast_data_and_get_pointer<T>(this->ctx_);
+  T *y = outputs[0]->cast_data_and_get_pointer<T>(this->ctx_);
 
   // TODO: Write implementation
 }
 
-
 template <typename T>
 void GlobalAveragePooling<T>::backward_impl(const Variables &inputs,
-                                           const Variables &outputs,
-                                           const vector<bool> &propagate_down,
-                                           const vector<bool> &accum) {
+                                            const Variables &outputs,
+                                            const vector<bool> &propagate_down,
+                                            const vector<bool> &accum) {
   // TODO: Remove debug message
   std::cout << "GlobalAveragePooling<" << typeid(T).name()
             << ">::backward_impl called with " << this->ctx_.to_string() << "."
```
```diff
@@ -89,18 +87,19 @@ void GlobalAveragePooling<T>::backward_impl(const Variables &inputs,
   /* TODO: remove this help message.
      The propagate down flags are automatically set by our graph engine, which
      specifies whether each input variable of them requires gradient
-     computation.
+     computation.
   */
   if (!(propagate_down[0])) {
     return;
   }
 
   /** TODO: remove this help message.
      The backward error signals are propagated through the graph, and the
-     errors from descendant functions are set in the grad region of the output variables.
+     errors from descendant functions are set in the grad region of the output
+     variables.
   */
   // Gradient of outputs
-  const T* g_y = outputs[0]->get_grad_pointer<T>(this->ctx_);
+  const T *g_y = outputs[0]->get_grad_pointer<T>(this->ctx_);
 
   /* TODO: remove this help message.
      The backward error signal should be propagated to the grad region of input
@@ -111,7 +110,7 @@ void GlobalAveragePooling<T>::backward_impl(const Variables &inputs,
      by substitution or accumulation.
   */
   // Gradient of inputs
-  T* g_x{nullptr};
+  T *g_x{nullptr};
 
   if (propagate_down[0]) {
     g_x = inputs[0]->cast_grad_and_get_pointer<T>(this->ctx_);
```
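The backward hunks document two flags set by the graph engine: `propagate_down[i]` says whether input i needs a gradient at all, and `accum[i]` says whether the computed gradient should be added to or overwrite the existing grad region ("substitution or accumulation"). Here is a hedged sketch of how the remaining body might use them, under the same NCHW assumption as above: each input element contributed 1/(H*W) to the mean, so the output gradient spreads back uniformly over the feature map.

```cpp
template <typename T>
void GlobalAveragePooling<T>::backward_impl(const Variables &inputs,
                                            const Variables &outputs,
                                            const vector<bool> &propagate_down,
                                            const vector<bool> &accum) {
  if (!propagate_down[0]) {
    return;
  }
  const T *g_y = outputs[0]->get_grad_pointer<T>(this->ctx_);
  T *g_x = inputs[0]->cast_grad_and_get_pointer<T>(this->ctx_);
  const Shape_t shape = inputs[0]->shape(); // assumed NCHW
  const Size_t n_maps = shape[0] * shape[1];
  const Size_t map_size = shape[2] * shape[3];
  for (Size_t m = 0; m < n_maps; ++m) {
    const T g = g_y[m] / static_cast<T>(map_size);
    for (Size_t i = 0; i < map_size; ++i) {
      if (accum[0]) {
        g_x[m * map_size + i] += g; // accumulate into the existing gradient
      } else {
        g_x[m * map_size + i] = g; // substitute (overwrite)
      }
    }
  }
}
```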
2 changes: 0 additions & 2 deletions src/nbla/function/global_average_pooling.cpp
```diff
@@ -25,8 +25,6 @@ namespace nbla {
 // float
 template class GlobalAveragePooling<float>;
 
-
-
 // half
 template class GlobalAveragePooling<Half>;
 }
```
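This third file is the non-generic translation unit: the template definition lives in `generic/global_average_pooling.cpp`, and explicit instantiation (`template class ...;`) forces the compiler to emit code for each supported scalar type, here `float` and `Half`. A self-contained toy example of the same pattern, with illustrative names that are not part of nnabla:

```cpp
#include <iostream>

// Stand-in for a templated implementation that would normally live in a
// separate "generic" file and be textually included here.
template <typename T> struct MeanOfTwo {
  static T of(T a, T b) { return (a + b) / static_cast<T>(2); }
};

// Explicit instantiation: object code for these types is emitted in this
// translation unit, so other files can link against it without seeing the
// template body at their point of use.
template struct MeanOfTwo<float>;
template struct MeanOfTwo<double>;

int main() {
  std::cout << MeanOfTwo<float>::of(1.0f, 2.0f) << "\n"; // prints 1.5
  return 0;
}
```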
