Skip to content

Commit

Permalink
Xが2次元、Yが1次元の場合の処理を追加
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 20, 2018
1 parent cf6957c commit 75c2b44
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
10 changes: 9 additions & 1 deletion python/test/function/test_broadcast_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,25 @@
function_tester,
list_ctx_and_func_name)


def copying_to_leaf(x, y, axis):
return (len(x.shape) - len(y.shape) - axis) == 0


def ref_broadcast_to(x, y, axis):
if axis < 0 or copying_to_leaf(x, y, axis):
# Copy data to leaf
return np.ones(x.shape) * y
else:
return np.ones(x.shape)
# Copy data from specified axis
if len(x.shape) == 2:
t = y[:,np.newaxis]
t.transpose()
return np.tile(t, (1, x.shape[1]))


PARAMS = [
((2, 3), (2), 0),
((2, 3), (3), 1),
((2, 3, 4), (4), 2),
((2, 3, 4), (3, 4), 1),
Expand Down
13 changes: 11 additions & 2 deletions src/nbla/function/broadcast_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,17 @@ void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outp
// Y dimension size is always 1
switch(axis_) {
case 0:
// X: (2,3) Y: (2) axis=0
// copy Y values vertically
{
// X: (2,3) Y: (2) axis=0
// copy Y values vertically
Size_t height = xs[0];
Size_t width = xs[1];
for (Size_t v=0; v<height; v++) {
T val = y[v];
T* zrow = &z[v*width];
std::fill(zrow, zrow+width, val);
}
}
break;
case 1:
// X: (2,3) Y: (3) axis=1
Expand Down

0 comments on commit 75c2b44

Please sign in to comment.