Skip to content

Commit

Permalink
Merge pull request #392 from TylunasLi/bug_fix_att_opt
Browse files Browse the repository at this point in the history
修复非batch下CPU Attention算子取batch错误(#385)
  • Loading branch information
ztxz16 committed Dec 26, 2023
2 parents 4271f82 + c2fb785 commit 46a9918
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/devices/cpu/cpudevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,9 @@ namespace fastllm {
float *vd = (float*)v.cpuData;
float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr;
float *od = (float*)output.cpuData;
int batch = intParams.find("q___batch")->second;
int maskStride = (datas.find("mask")->second) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0;
int batch = (maskd != nullptr && mask.dims.size() == 3) ? mask.dims[0] : 1;
batch = intParams.find("mask___batch") != intParams.end() ? intParams.find("mask___batch")->second : batch;
int maskStride = (maskd != nullptr) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0;
std::fill(od, od + output.Count(0), 0.0f);
auto pool = GetPool();
std::vector<std::future<void> > futures;
Expand Down

0 comments on commit 46a9918

Please sign in to comment.