Skip to content

Commit

Permalink
Merge branch 'feature/20180401-file-format-converter' of ml-git.ubiq.…
Browse files Browse the repository at this point in the history
…sony.co.jp:nnabla/nnabla into feature/20180401-file-format-converter
  • Loading branch information
YukioOobuchi committed Jun 11, 2018
2 parents 6aafe0a + 275a8f9 commit afec556
Show file tree
Hide file tree
Showing 21 changed files with 6,045 additions and 588 deletions.
3 changes: 3 additions & 0 deletions build-tools/code_generator/function_types.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ Transpose:
Broadcast:
float: [float]
half: [Half]
BroadcastTo:
float: [float]
half: [Half]
OneHot:
float: [int, float]
half: [int, Half]
Expand Down
18 changes: 18 additions & 0 deletions build-tools/code_generator/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2160,6 +2160,24 @@ Array Manipulation:
outputs:
y:
doc: Broadcasted N-D array
BroadcastTo:
snake_name: broadcast_to
doc: |2
Broadcasting ND-array to the specified buffer.
inputs:
x:
doc: N-D array
y:
doc: N-D array
arguments:
axis:
doc: Target axis to start broadcasting. If this is not set, broadcast will try to fit y to x starting from the last dimension
type: int64
default: -1
outputs:
z:
doc: Broadcasted N-D array
OneHot:
snake_name: one_hot
doc: |2
Expand Down
44 changes: 44 additions & 0 deletions doc/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3652,6 +3652,50 @@ Broadcasting ND-array to the specified shape.
- Broadcasted N-D array
-

BroadcastTo
^^^^^^^^^^^

Broadcasting ND-array to the specified buffer

* Input(s)

.. list-table::

* - Name
- Description
- Options
* - x
- N-D array
-
* - y
- N-D array
-

* Argument(s)

.. list-table::

* - Name
- Type
- Default
- Description
* - axis
- int64
- -1
- Target axis to start broadcasting. If this is not set, broadcast will try to fit y to x starting from the last dimension


* Output(s)

.. list-table::

* - Name
- Description
- Options
* - z
- Broadcasted N-D array
-

OneHot
^^^^^^

Expand Down
71 changes: 71 additions & 0 deletions include/nbla/function/broadcast_to.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// *WARNING*
// THIS FILE IS AUTO-GENERATED DUMMY CODE BY CODE GENERATOR.
// PLEASE IMPLEMENT REAL CODE AND DELETE THIS MESSAGE SOON.
// If you want to change dummy code, edit following files.
// - build-tools/code_generator/function_generator/generate_include_nbla_function_hpp.py
// - build-tools/code_generator/templates/include_nbla_function_hpp_template.hpp


/** BroadcastTo
*/
#ifndef __NBLA_FUNCTION_BROADCASTTO_HPP__
#define __NBLA_FUNCTION_BROADCASTTO_HPP__

#include <nbla/cpu.hpp>
#include <nbla/function.hpp>
#include <nbla/function_registry.hpp>

namespace nbla {

NBLA_REGISTER_FUNCTION_HEADER(BroadcastTo, int);

/**
@todo PLACE HERE FUNCTION DOCUMENTATION.
*/
template <typename T>
class BroadcastTo : public BaseFunction<int> {
protected:
int axis_;

public:
BroadcastTo(const Context &ctx, int axis) : BaseFunction<int>(ctx, axis), axis_(axis) {}
virtual ~BroadcastTo() {}
virtual shared_ptr<Function> copy() const {
return create_BroadcastTo(ctx_, axis_);
}
virtual vector<dtypes> in_types() {
return vector<dtypes>{ get_dtype<T>(), get_dtype<T>() };
}
virtual vector<dtypes> out_types() {
return vector<dtypes>{ get_dtype<T>() };
}
virtual int min_inputs() { return 2; }
virtual int min_outputs() { return 1; }
virtual string name() { return "BroadcastTo"; }
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cpu>()->array_classes();
}

protected:
NBLA_API virtual void setup_impl(const Variables &inputs, const Variables &outputs);
NBLA_API virtual void forward_impl(const Variables &inputs, const Variables &outputs);
NBLA_API virtual void backward_impl(const Variables &inputs, const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum);
};
}
#endif
1 change: 1 addition & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
'scikit-image',
'scipy',
'tqdm',
'onnx',
]


Expand Down

0 comments on commit afec556

Please sign in to comment.