Skip to content

Commit

Permalink
prompt单独处理,节约一点显存
Browse files (browse the repository at this point in the history)
  • Loading branch information
huangyuyang committed Aug 30, 2023
1 parent 8286d1d commit d7e3335
Showing 1 changed file with 56 additions and 45 deletions.
101 changes: 56 additions & 45 deletions src/models/basellm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -339,57 +339,68 @@ namespace fastllm {
LastTokensManager tokensManager;
std::vector <std::vector <float>* > logits;
model->dictLocker.lock();
for (auto &it: model->responseContextDict.dicts) {
if (it.second->isEnding) {
for (int isPrompt = 1; isPrompt >= 0; isPrompt--) {
if (isPrompt == 0 && seqLens.size() > 0) {
continue;
}
generationConfigs.push_back(it.second->generationConfig);
if (it.second->generationConfig.output_logits) {
it.second->resultLogits.push(new std::vector <float> ());
logits.push_back(it.second->resultLogits.back());
} else {
logits.push_back(nullptr);
}
for (auto &it: model->responseContextDict.dicts) {
if (it.second->isEnding) {
continue;
}
if (isPrompt && it.second->preTokens != 0) {
continue;
}
generationConfigs.push_back(it.second->generationConfig);
if (it.second->generationConfig.output_logits) {
it.second->resultLogits.push(new std::vector<float>());
logits.push_back(it.second->resultLogits.back());
} else {
logits.push_back(nullptr);
}

tokensManager.units.push_back(it.second->tokens);
handles.push_back(it.first);
tokensManager.units.push_back(it.second->tokens);
handles.push_back(it.first);

if (it.second->preTokens == 0) {
it.second->intParams["promptLen"] = it.second->currentTokens.size();
it.second->intParams["index"] = 0;
} else {
it.second->intParams["index"]++;
}
Data inputIds, attentionMask, curPositionIds;
std::vector <std::vector <float> > tokens;
tokens.resize(1);
for (int i : it.second->currentTokens) {
tokens[0].push_back(i);
}
model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask, curPositionIds);
seqLens.push_back(inputIds.Count(0));
for (int i = 0; i < inputIds.Count(0); i++) {
ids.push_back(((float*)inputIds.cpuData)[i]);
}
if (attentionMask.dims.size() == 0) {
attentionMasks.push_back(nullptr);
} else {
attentionMasks.push_back(new Data());
attentionMasks.back()->CopyFrom(attentionMask);
}
if (curPositionIds.dims.size() == 0) {
positionIds.push_back(nullptr);
} else {
positionIds.push_back(new Data());
positionIds.back()->CopyFrom(curPositionIds);
}
it.second->preTokens += seqLens.back();
for (int i = 0; i < model->block_cnt; i++) {
pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first,
&it.second->pastKeyValues[i].second));
if (it.second->preTokens == 0) {
it.second->intParams["promptLen"] = it.second->currentTokens.size();
it.second->intParams["index"] = 0;
} else {
it.second->intParams["index"]++;
}
Data inputIds, attentionMask, curPositionIds;
std::vector<std::vector<float> > tokens;
tokens.resize(1);
for (int i: it.second->currentTokens) {
tokens[0].push_back(i);
}
model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask,
curPositionIds);
seqLens.push_back(inputIds.Count(0));
for (int i = 0; i < inputIds.Count(0); i++) {
ids.push_back(((float *) inputIds.cpuData)[i]);
}
if (attentionMask.dims.size() == 0) {
attentionMasks.push_back(nullptr);
} else {
attentionMasks.push_back(new Data());
attentionMasks.back()->CopyFrom(attentionMask);
}
if (curPositionIds.dims.size() == 0) {
positionIds.push_back(nullptr);
} else {
positionIds.push_back(new Data());
positionIds.back()->CopyFrom(curPositionIds);
}
it.second->preTokens += seqLens.back();
for (int i = 0; i < model->block_cnt; i++) {
pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first,
&it.second->pastKeyValues[i].second));
}
if (isPrompt) {
break;
}
}
}

if (seqLens.size() > 0) {
std::vector <std::pair <Data, Data> > *pastKeyValue1;
if (seqLens.size() == 1) {
Expand Down

0 comments on commit d7e3335

Please sign in to comment.