Commit

Add skeletons for BroadcastTo and its tests
Masato Hori committed Apr 19, 2018
1 parent f6a7d1f commit b301174
Showing 4 changed files with 87 additions and 58 deletions.
4 changes: 2 additions & 2 deletions doc/functions.rst
@@ -3514,8 +3514,8 @@ Broadcasting ND-array to the specified buffer
      - Default
      - Description
    * - axis
-     - Shape
-     -
+     - int64
+     - -1
      - Target axis to start broadcasting. If this is not set, broadcast_to will try to fit y to x starting from the last dimension.


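To illustrate the new axis semantics from the Python API, here is a short sketch. It assumes the generated binding follows the usual nnabla pattern F.broadcast_to(x, y, axis=-1); the shapes are examples only and are not part of this commit.

    import nnabla as nn
    import nnabla.functions as F

    x = nn.Variable((2, 3, 4, 5))
    y = nn.Variable((4, 5))
    # Default axis=-1: y is matched against the trailing dimensions of x,
    # so (4, 5) fits (2, 3, 4, 5) and the output takes x's shape.
    z = F.broadcast_to(x, y)
    assert z.shape == (2, 3, 4, 5)
    # axis=1: a y of shape (3, 4) is matched against dimensions 1 and 2 of x.
    z2 = F.broadcast_to(x, nn.Variable((3, 4)), axis=1)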
8 changes: 4 additions & 4 deletions include/nbla/function/broadcast_to.hpp
@@ -31,18 +31,18 @@

 namespace nbla {
 
-NBLA_REGISTER_FUNCTION_HEADER(BroadcastTo, const vector<int> &);
+NBLA_REGISTER_FUNCTION_HEADER(BroadcastTo, int);
 
 /**
 @todo PLACE HERE FUNCTION DOCUMENTATION.
 */
 template <typename T>
-class BroadcastTo : public BaseFunction<const vector<int> &> {
+class BroadcastTo : public BaseFunction<int> {
 protected:
-  const vector<int> axis_;
+  int axis_;
 
 public:
-  BroadcastTo(const Context &ctx, const vector<int> & axis) : BaseFunction<const vector<int> &>(ctx, axis), axis_(axis) {}
+  BroadcastTo(const Context &ctx, int axis) : BaseFunction<int>(ctx, axis), axis_(axis) {}
   virtual ~BroadcastTo() {}
   virtual shared_ptr<Function> copy() const {
     return create_BroadcastTo(ctx_, axis_);
96 changes: 46 additions & 50 deletions python/test/function/test_broadcast_to.py
@@ -12,53 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-#import pytest
-#import numpy as np
-#import nnabla as nn
-#import nnabla.functions as F
-#
-#from nbla_test_utils import (
-#    function_tester,
-#    list_ctx_and_func_name)
-#
-#
-#def ref_broadcast(x, shape):
-#    return x * np.ones(shape, dtype=x.dtype)
-#
-#
-#def get_combination(n):
-#    if n == 0:
-#        return [(n, np.array([], dtype=np.bool))]
-#    all_comb = np.vstack(map(lambda x: x.flatten(), np.meshgrid(
-#        *[[0, 1] for _ in range(n)]))).T.astype(np.bool)
-#    return [(n, comb) for comb in all_comb]
-#
-#
-#def get_combinations(*N):
-#    ret = []
-#    for n in N:
-#        ret.extend(get_combination(n))
-#    return ret
-#
-#
-#@pytest.mark.parametrize("seed", [314])
-#@pytest.mark.parametrize("fname, ctx, func_name", list_ctx_and_func_name(['broadcast']))
-#@pytest.mark.parametrize("ndim, broadcast_dim", get_combinations(*range(0, 6)))
-#def test_broadcast_forward_backward(ndim, broadcast_dim, seed, fname, ctx, func_name):
-#    func = getattr(F, fname)
-#    ref_func = eval('ref_' + fname)
-#    rng = np.random.RandomState(seed)
-#    shape = rng.randint(2, 5, size=(ndim,))
-#    inshape = shape.copy()
-#    inshape[broadcast_dim] = 1
-#    if np.prod(inshape) == 1:
-#        # Performing 0-dim array test too.
-#        inputs = [np.array(rng.randn())]
-#        function_tester(rng, func, ref_func, inputs, [shape],
-#                        ctx=ctx, backward=[True], func_name=func_name,
-#                        atol_b=4e-3)
-#
-#    inputs = [np.array(rng.randn(*inshape))]
-#    function_tester(rng, func, ref_func, inputs, [shape],
-#                    ctx=ctx, backward=[True], func_name=func_name,
-#                    atol_b=4e-3)
+import pytest
+import numpy as np
+import nnabla as nn
+import nnabla.functions as F
+import pdb
+
+from nbla_test_utils import (
+    function_tester,
+    list_ctx_and_func_name)
+
+
+def ref_broadcast_to(x, y, axis):
+    return x  # skeleton placeholder: the real reference should broadcast y to x's shape
+
+
+PARAMS = [
+    #((2, 3, 4, 5), (5,), -1),
+    ((2, 3, 4, 5), (4, 5), -1),
+    #((2, 3, 4, 5), (3, 4), 1),
+    #((2, 3, 4, 5), (2,), 0),
+]
+
+@pytest.mark.parametrize("seed", [314])
+@pytest.mark.parametrize("fname, ctx, func_name", list_ctx_and_func_name(['broadcast_to']))
+@pytest.mark.parametrize("xs, ys, axis", PARAMS)
+def test_broadcast_to_forward_backward(xs, ys, axis, seed, fname, ctx, func_name):
+    rng = np.random.RandomState(seed)
+    ref_func = eval('ref_' + fname)
+    func = getattr(F, fname)
+    inputs = [rng.randn(*xs), rng.randn(*ys)]
+    function_tester(rng, func, ref_func, inputs, [axis],
+                    ctx=ctx, func_name=func_name,
+                    atol_b=4e-3)
+    #shape = rng.randint(2, 5, size=(ndim,))
+    #inshape = shape.copy()
+    #inshape[broadcast_dim] = 1
+    #if np.prod(inshape) == 1:
+    #    # Performing 0-dim array test too.
+    #    inputs = [np.array(rng.randn())]
+    #    function_tester(rng, func, ref_func, inputs, [shape],
+    #                    ctx=ctx, backward=[True], func_name=func_name,
+    #                    atol_b=4e-3)
+    #inputs = [np.array(rng.randn(*inshape))]
+    #function_tester(rng, func, ref_func, inputs, [shape],
+    #                ctx=ctx, backward=[True], func_name=func_name,
+    #                atol_b=4e-3)
+    pass
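Note that ref_broadcast_to above is still a stub that returns x unchanged. Assuming the intended semantics (y is fitted to x's shape starting at axis, or at the tail when axis < 0, as checked in the C++ below), a NumPy reference might eventually look like this hypothetical sketch:

    import numpy as np

    def ref_broadcast_to_sketch(x, y, axis):
        # Align y's shape with x's dimensions starting at `axis`
        # (or at the tail when axis < 0), then let NumPy broadcast.
        ofs = x.ndim - y.ndim if axis < 0 else axis
        shape = [1] * x.ndim
        shape[ofs:ofs + y.ndim] = y.shape
        return np.broadcast_to(y.reshape(shape), x.shape)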
37 changes: 35 additions & 2 deletions src/nbla/function/broadcast_to.cpp
@@ -29,11 +29,44 @@

 namespace nbla {
 
-NBLA_REGISTER_FUNCTION_SOURCE(BroadcastTo, const vector<int> &);
+NBLA_REGISTER_FUNCTION_SOURCE(BroadcastTo, int);
 
 template <typename T>
 void BroadcastTo<T>::setup_impl(const Variables &inputs, const Variables &outputs) {
-  // TODO TEMPLATE CODE
+  const Shape_t xs = inputs[0]->shape();
+  const Shape_t ys = inputs[1]->shape();
+  const int xss = xs.size();
+  const int yss = ys.size();
+  NBLA_CHECK(xss >= yss, error_code::value,
+             "BroadcastTo expects Y (the variable to be broadcast) to have no more dimensions than X (the target variable to fit to): %d vs %d",
+             yss, xss);
+  if (axis_ < 0) {
+    // No axis was specified:
+    // check whether the y shape fits the x shape from the tail dimension.
+    const int xofs = xss - yss;
+    for (int i = yss - 1; i >= 0; i--) {
+      Size_t xds = xs[xofs + i];
+      Size_t yds = ys[i];
+      NBLA_CHECK(xds == yds, error_code::value,
+                 "Dimension %d's size of X and Y do not match: %d vs %d",
+                 xofs + i, xds, yds);
+    }
+  } else {
+    NBLA_CHECK(axis_ < xss, error_code::value,
+               "Specified axis index %d must be smaller than the number of input dimensions %d",
+               axis_, xss);
+    // Check whether the y shape fits the x shape from the axis index.
+    for (int i = 0; i < yss; i++) {
+      Size_t xds = xs[i + axis_];
+      Size_t yds = ys[i];
+      NBLA_CHECK(xds == yds, error_code::value,
+                 "Dimension %d's size of X and Y do not match: %d vs %d",
+                 i + axis_, xds, yds);
+    }
+  }
+  // All checks passed.
+  // Reshape the output to fit X.
+  outputs[0]->reshape(xs, true);
 }
 
 template <typename T>
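The shape validation in setup_impl reduces to one alignment rule: pick an offset into x's dimensions (the tail when axis_ < 0, otherwise axis_) and require every dimension of y to match x from there on. A compact Python rendering of the same checks (a hypothetical helper, not part of nbla; it adds an explicit bounds guard on top of the axis_ < xss check above):

    def check_broadcast_to_shapes(xs, ys, axis=-1):
        # Y must not have more dimensions than X.
        assert len(ys) <= len(xs), "y has more dimensions than x"
        # axis < 0 aligns y with the trailing dimensions of x;
        # otherwise y is aligned starting at `axis`.
        ofs = len(xs) - len(ys) if axis < 0 else axis
        assert ofs + len(ys) <= len(xs), "y runs past the end of x"
        for i, yd in enumerate(ys):
            assert xs[ofs + i] == yd, "dimension %d mismatch" % (ofs + i)
        return tuple(xs)  # the output takes x's shape

    check_broadcast_to_shapes((2, 3, 4, 5), (4, 5))           # fits the tail
    check_broadcast_to_shapes((2, 3, 4, 5), (3, 4), axis=1)   # fits dims 1..2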
