Skip to content

Commit

Permalink
BroadcastToでaxisが指定されてない場合のforwardを実装 (Implement forward for BroadcastTo when no axis is specified)
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 19, 2018
1 parent b301174 commit 36f502e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 25 deletions.
27 changes: 8 additions & 19 deletions python/test/function/test_broadcast_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
list_ctx_and_func_name)

def ref_broadcast_to(x, y, axis):
    """Reference implementation of broadcast_to.

    Broadcasts ``y`` up to the shape of ``x``.

    Args:
        x: Target array; only its shape matters.
        y: Array to be broadcast. When ``axis < 0`` its shape must match
           the trailing dimensions of ``x``; otherwise its dimensions are
           aligned with ``x`` starting at ``axis``.
        axis: Alignment axis, or a negative value for trailing alignment.

    Returns:
        An array with the shape of ``x`` filled with the broadcast ``y``.
    """
    if axis < 0:
        # Trailing alignment: numpy's default broadcasting rules apply.
        return np.ones(x.shape) * y
    else:
        # Align y's dimensions with x starting at `axis`: pad y's shape
        # with singleton dimensions on both sides, then broadcast.
        shape = [1] * x.ndim
        shape[axis:axis + y.ndim] = y.shape
        return np.ones(x.shape) * y.reshape(shape)


PARAMS = [
#((2, 3, 4, 5), (5), -1),
((2, 3, 4, 5), (5), -1),
((2, 3, 4, 5), (4, 5), -1),
#((2, 3, 4, 5), (3, 4), 1),
#((2, 3, 4, 5), (2), 0),
Expand All @@ -40,21 +43,7 @@ def test_broadcast_to_forward_backward(xs, ys, axis, seed, fname, ctx, func_name
rng = np.random.RandomState(seed)
ref_func = eval('ref_' + fname)
func = getattr(F, fname)
inputs = [rng.randn(*xs), rng.randn(*ys)]
inputs = [rng.random_sample(xs), rng.random_sample(ys)]
function_tester(rng, func, ref_func, inputs, [axis],
ctx=ctx, func_name=func_name,
atol_b=4e-3)
#shape = rng.randint(2, 5, size=(ndim,))
#inshape = shape.copy()
#inshape[broadcast_dim] = 1
#if np.prod(inshape) == 1:
# # Performing 0-dim array test too.
# inputs = [np.array(rng.randn())]
# function_tester(rng, func, ref_func, inputs, [shape],
# ctx=ctx, backward=[True], func_name=func_name,
# atol_b=4e-3)
#inputs = [np.array(rng.randn(*inshape))]
#function_tester(rng, func, ref_func, inputs, [shape],
# ctx=ctx, backward=[True], func_name=func_name,
# atol_b=4e-3)
pass
backward=[False,False],
ctx=ctx, func_name=func_name)
37 changes: 31 additions & 6 deletions src/nbla/function/broadcast_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ template <typename T>
void BroadcastTo<T>::setup_impl(const Variables &inputs, const Variables &outputs) {
const Shape_t xs = inputs[0]->shape();
const Shape_t ys = inputs[1]->shape();
const int xss = xs.size();
const int yss = ys.size();
const Size_t xss = xs.size();
const Size_t yss = ys.size();
NBLA_CHECK(xss >= yss, error_code::value,
"BroadcastTo expects Y (variable to be broadcasted) to be smaller than or equal to X (target variable we want to fit to): %d vs %d",
yss, xss);
if (axis_ < 0) {
// No axis was specified.
// Check if y shape can fit x shape from the tail dimension
const int xofs = xss - yss;
for (int i=yss-1; i>=0; i--) {
const Size_t xofs = xss - yss;
for (Size_t i=yss-1; i>=0; i--) {
Size_t xds = xs[xofs+i];
Size_t yds = ys[i];
NBLA_CHECK(xds == yds, error_code::value,
Expand All @@ -56,7 +56,7 @@ void BroadcastTo<T>::setup_impl(const Variables &inputs, const Variables &output
"Specified axis index %d must be within the size of the actual input dimension %d",
axis_, xss);
// Check if y shape can fit x shape from the axis index
for (int i=0; i<yss; i++) {
for (Size_t i=0; i<yss; i++) {
Size_t xds = xs[i+axis_];
Size_t yds = ys[i];
NBLA_CHECK(xds == yds, error_code::value,
Expand All @@ -71,7 +71,32 @@ void BroadcastTo<T>::setup_impl(const Variables &inputs, const Variables &output

template <typename T>
void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outputs) {
  // Broadcast Y (inputs[1]) to the shape of X (inputs[0]) and write the
  // result into the output Z.  setup_impl has already validated that Y's
  // shape matches X's trailing dimensions (axis_ < 0) or X's dimensions
  // starting at axis_ (axis_ >= 0).
  const T *y = inputs[1]->get_data_pointer<T>(this->ctx_);
  T *z = outputs[0]->cast_data_and_get_pointer<T>(this->ctx_);
  const Shape_t xs = inputs[0]->shape();
  const Shape_t ys = inputs[1]->shape();
  const Size_t ysize = inputs[1]->size();
  const Size_t xss = xs.size();
  const Size_t yss = ys.size();
  if (xss == yss) {
    // X and Y have exactly the same shape: Z is a plain copy of Y.
    std::copy(y, y + ysize, z);
    return;
  }
  if (axis_ < 0) {
    // No axis was specified: Y matches the trailing dimensions of X, so Z
    // is Y repeated once per combination of X's leading indices.
    const Size_t diff = xss - yss;
    Size_t loop_num = 1;
    for (Size_t i = 0; i < diff; i++) {
      loop_num *= xs[i];
    }
    for (Size_t i = 0; i < loop_num; i++) {
      std::copy(y, y + ysize, z + i * ysize);
    }
  } else {
    // Y occupies dimensions [axis_, axis_ + yss) of X (checked in
    // setup_impl): repeat the whole Y block once per leading index of X,
    // and repeat each Y element once per trailing index of X.
    Size_t outer = 1;
    for (Size_t i = 0; i < axis_; i++) {
      outer *= xs[i];
    }
    Size_t inner = 1;
    for (Size_t i = axis_ + yss; i < xss; i++) {
      inner *= xs[i];
    }
    for (Size_t o = 0; o < outer; o++) {
      for (Size_t j = 0; j < ysize; j++) {
        // Each Y element fills a contiguous run of `inner` Z elements.
        T *zp = z + (o * ysize + j) * inner;
        std::fill(zp, zp + inner, y[j]);
      }
    }
  }
}

template <typename T>
Expand Down

0 comments on commit 36f502e

Please sign in to comment.