Skip to content

Commit

Permalink
Xが4D、Yが3D、axisが0の場合に対応
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 23, 2018
1 parent 94e7037 commit 9244e0c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
6 changes: 5 additions & 1 deletion python/test/function/test_broadcast_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def ref_broadcast_to(x, y, axis):
elif axis == 1:
t = y[np.newaxis, :, :, np.newaxis]
return np.broadcast_to(t, x.shape)
elif ys == 3:
if axis == 0:
t = y[:, :, :, np.newaxis]
return np.broadcast_to(t, x.shape)



Expand All @@ -91,7 +95,7 @@ def ref_broadcast_to(x, y, axis):
((2, 3, 4, 5), (2, 3), 0),
((2, 3, 4, 5), (3, 4), 1),
((2, 3, 4, 5), (4, 5), 2),
#((2, 3, 4, 5), (2, 3, 4), 0),
((2, 3, 4, 5), (2, 3, 4), 0),
((2, 3, 4, 5), (3, 4, 5), 1),
((2, 3, 4, 5), (5), -1),
((2, 3, 4, 5), (4, 5), -1),
Expand Down
22 changes: 21 additions & 1 deletion src/nbla/function/broadcast_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,27 @@ void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outp
case 3:
switch(axis_) {
case 0:
// X: (2,3,4,5) Y: (2,3,4) axis=0
{
// X: (2,3,4,5) Y: (2,3,4) axis=0
Size_t y_buf_width = ys[2];
Size_t y_slice_size = ys[1]*y_buf_width;
Size_t z_row_width = xs[3];
Size_t z_slice_size = xs[2]*xs[3];
Size_t z_chan_size = xs[1]*z_slice_size;
for (Size_t n=0; n<xs[0]; n++) {
const T* ychan = &y[n*y_slice_size];
T* zslice = &z[n*z_chan_size];
for (Size_t c=0; c<xs[1]; c++) {
const T* yrow = &ychan[c*y_buf_width];
T* zblock = &zslice[c*z_slice_size];
for (Size_t h=0; h<xs[2]; h++) {
T val = yrow[h];
T* zrow = &zblock[h*z_row_width];
std::fill(zrow, zrow+z_row_width, val);
}
}
}
}
break;
case 1:
// X: (2,3,4,5) Y: (3,4,5) axis=1
Expand Down

0 comments on commit 9244e0c

Please sign in to comment.