Skip to content

Commit

Permalink
Merge pull request apache#146 from sxjscience/fix_broadcast_to_k40
Browse files Browse the repository at this point in the history
Fix broadcast_to on K20/K40
  • Loading branch information
tqchen committed Jul 4, 2016
2 parents c100c92 + db6cb22 commit cf72793
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions mshadow/extension/broadcast_with_axis.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,17 @@ struct BroadcastWithMultiAxesExp :
  }
  for (index_t i = 0; i < dimsrc; i++) {
  this->shape_[i] = src_shape[i];
+ this->sizes_[i] = 1;
+ this->trailings_[i] = 1;
  }
  for (index_t i = 0; i < this->axesnum_; i++) {
  this->shape_[axes[i]] = sizes[i];
  this->sizes_[i] = sizes[i];
  }
- if (this->axesnum_ > 0) {
- for (index_t i = 0; i < this->axesnum_; i++) {
- this->trailings_[i] = 1;
- for (index_t j = axes[i] + 1; j < dimsrc; ++j) {
- this->trailings_[i] *= this->shape_[j];
- }
+ for (index_t i = 0; i < this->axesnum_; i++) {
+ this->trailings_[i] = 1;
+ for (index_t j = axes[i] + 1; j < dimsrc; ++j) {
+ this->trailings_[i] *= this->shape_[j];
  }
  }
  this->last_ = src_shape[dimsrc - 1];
Expand Down Expand Up @@ -237,7 +237,10 @@ struct Plan<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>, DType> {
  trailings_(e.trailings_), sizes_(e.sizes_) {}
  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
  index_t indx = i * dst_last_ + j;
- for (index_t p = 0; p < axesnum_; ++p) {
+ for (index_t p = 0; p < dimsrc; ++p) {
+ if (p >= axesnum_) {
+ break;
+ }
  indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]);
  }
  return src_.Eval(indx / last_, indx % last_);
Expand Down

0 comments on commit cf72793

Please sign in to comment.