From 7ca215397016234a2fc933b21b581c3a09bacdef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Wed, 10 Jul 2024 10:37:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96graphllm=E7=9A=84=E5=B9=B6?= =?UTF-8?q?=E8=A1=8C=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/graph.cpp | 279 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 185 insertions(+), 94 deletions(-) diff --git a/src/graph.cpp b/src/graph.cpp index be20464f..181231a0 100644 --- a/src/graph.cpp +++ b/src/graph.cpp @@ -220,111 +220,202 @@ namespace fastllm { } } else { int batch = seqLens.size(), total = 0; - std::vector curQs, curKs, curVs, curOutputs; - curQs.resize(batch); - curKs.resize(batch); - curVs.resize(batch); - curOutputs.resize(batch); - for (int b = 0; b < batch; b++) { - excutor.Run("Split", { - {"input", allDatas[op.datas.find("q")->second]}, {"output", &curQs[b]} - }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}}); - excutor.Run("Split", { - {"input", allDatas[op.datas.find("curk")->second]}, {"output", &curKs[b]} - }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}}); - excutor.Run("Split", { - {"input", allDatas[op.datas.find("curv")->second]}, {"output", &curVs[b]} - }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}}); - total += seqLens[b]; - } - std::vector axis = {0, 2, 1, 3}; - Data axisData = Data(DataType::INT32PARAM, {(int)axis.size()}); - axisData.Allocate(); - for (int i = 0; i < axisData.Count(0); i++) { - ((int32_t*)axisData.cpuData)[i] = axis[i]; - } - for (int b = 0; b < batch; b++) { - excutor.Run("PermuteSelf", { - {"input", (Data*)&curQs[b]}, {"axis", &axisData} - }, {}, {}); - curQs[b].Reshape({-1, curQs[b].dims[2], curQs[b].dims[3]}); - - excutor.Run("PermuteSelf", { - {"input", (Data*)&curKs[b]}, {"axis", &axisData} - }, {}, {}); - curKs[b].Reshape({-1, curKs[b].dims[2], curKs[b].dims[3]}); - - excutor.Run("PermuteSelf", { - {"input", (Data*)&curVs[b]}, {"axis", &axisData} - }, {}, {}); - curVs[b].Reshape({-1, curVs[b].dims[2], curVs[b].dims[3]}); + bool all1 = true; + for (int i = 0; i < seqLens.size(); i++) { + if (seqLens[i] != 1) { + all1 = false; + break; + } } - - int unitLen = op.intParams.find("unitLen")->second; - for (int b = 0; b < batch; b++) { - for (int i = 0; i < 2; i++) { - auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "_" + std::to_string(b)]; - auto cur = i == 0 ? &curKs[b] : &curVs[b]; - while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1])) - || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) { - std::vector newDims; - if (cache->Count(0) == 0 || cache->dims.size() == 0) { - newDims = std::vector {cur->dims[0], ((cur->dims[1] - 1) / unitLen + 1) * unitLen, cur->dims[2]}; - } else { - newDims = cache->dims; - newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen; + + if (all1) { + std::vector curQs, curKs, curVs, curOutputs; + curQs.resize(batch); + curKs.resize(batch); + curVs.resize(batch); + curOutputs.resize(batch); + auto &q = *allDatas[op.datas.find("q")->second]; + auto &k = *allDatas[op.datas.find("curk")->second]; + auto &v = *allDatas[op.datas.find("curv")->second]; + + q.Reshape({-1, q.dims[2], q.dims[3]}); + k.Reshape({-1, k.dims[2], k.dims[3]}); + v.Reshape({-1, v.dims[2], v.dims[3]}); + int embed_dim = q.dims[1] * v.dims[2]; + + std::vector qdims = {q.dims[1], 1, q.dims[2]}; + std::vector qstrides = {(uint64_t)q.dims[2], (uint64_t)q.dims[2], 1}; + std::vector kdims = {k.dims[1], 1, k.dims[2]}; + std::vector kstrides = {(uint64_t)k.dims[2], (uint64_t)k.dims[2], 1}; + std::vector vdims = {v.dims[1], 1, v.dims[2]}; + std::vector vstrides = {(uint64_t)v.dims[2], (uint64_t)v.dims[2], 1}; + for (int b = 0; b < batch; b++) { + curQs[b].dims = qdims; + curQs[b].strides = qstrides; + curQs[b].FakeFrom(q, b * q.strides[0] * q.unitSize); + curKs[b].dims = kdims; + curKs[b].strides = kstrides; + curKs[b].FakeFrom(k, b * k.strides[0] * k.unitSize); + curVs[b].dims = vdims; + curVs[b].strides = vstrides; + curVs[b].FakeFrom(v, b * v.strides[0] * v.unitSize); + } + total = batch; + + int unitLen = op.intParams.find("unitLen")->second; + for (int i = 0; i < 2; i++) { + std::vector caches, curs; + for (int b = 0; b < batch; b++) { + auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "_" + std::to_string(b)]; + auto cur = i == 0 ? &curKs[b] : &curVs[b]; + while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1])) + || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) { + std::vector newDims; + if (cache->Count(0) == 0 || cache->dims.size() == 0) { + newDims = std::vector {cur->dims[0], ((cur->dims[1] - 1) / unitLen + 1) * unitLen, cur->dims[2]}; + } else { + newDims = cache->dims; + newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen; + } + cache->Expansion(newDims); } - cache->Expansion(newDims); - } - excutor.Run("CatDirect", { - {"input0", cache}, {"input1", cur} - }, {}, {{"axis", 1}}); + caches.push_back(cache); + curs.push_back(cur); + } + CatDirectBatch(caches, curs, 1); } - } - for (int b = 0; b < batch; b++) { - std::string sb = "_" + std::to_string(b); - Data *k = allDatas[op.datas.find("k")->second + sb]; - Data *v = allDatas[op.datas.find("v")->second + sb]; - Data *mask = allDatas[op.datas.find("mask")->second + sb]; - excutor.Run("Attention", { - {"q", (Data*)&curQs[b]}, {"k", k}, {"v", v}, - {"mask", mask}, {"output", (Data*)&curOutputs[b]} - }, {{"scale", op.floatParams.find("scale")->second}}, - {{"maskType", 0}}); - } - - for (int b = 0; b < batch; b++) { - std::vector axis = {1, 0, 2}; + auto &attenOutput = *allDatas[op.datas.find("output")->second]; + attenOutput.dataType = q.dataType; + attenOutput.ToDevice(q.dataDevice); + attenOutput.Resize({1, batch, embed_dim}); + attenOutput.Allocate(); + std::vector curContextLayer; + std::vector qs, keys, values, masks, contexts; + curContextLayer.resize(batch); + qs.resize(batch); + keys.resize(batch); + values.resize(batch); + masks.resize(batch); + contexts.resize(batch); + + for (int b = 0; b < batch; b++) { + std::string sb = "_" + std::to_string(b); + qs[b] = (&curQs[b]); + keys[b] = allDatas[op.datas.find("k")->second + sb]; + values[b] = allDatas[op.datas.find("v")->second + sb]; + masks[b] = allDatas[op.datas.find("mask")->second + sb]; + curContextLayer[b].FakeFrom(attenOutput, b * embed_dim * attenOutput.unitSize); + contexts[b] = (&curContextLayer[b]); + } + AttentionBatch(qs, keys, values, masks, contexts, qs[0]->dims[0] / values[0]->dims[0], op.floatParams.find("scale")->second, 1); + } else { + std::vector curQs, curKs, curVs, curOutputs; + curQs.resize(batch); + curKs.resize(batch); + curVs.resize(batch); + curOutputs.resize(batch); + for (int b = 0; b < batch; b++) { + excutor.Run("Split", { + {"input", allDatas[op.datas.find("q")->second]}, {"output", &curQs[b]} + }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}}); + excutor.Run("Split", { + {"input", allDatas[op.datas.find("curk")->second]}, {"output", &curKs[b]} + }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}}); + excutor.Run("Split", { + {"input", allDatas[op.datas.find("curv")->second]}, {"output", &curVs[b]} + }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}}); + total += seqLens[b]; + } + std::vector axis = {0, 2, 1, 3}; Data axisData = Data(DataType::INT32PARAM, {(int)axis.size()}); axisData.Allocate(); for (int i = 0; i < axisData.Count(0); i++) { ((int32_t*)axisData.cpuData)[i] = axis[i]; } - Data *output = (Data*)&curOutputs[b]; - excutor.Run("PermuteSelf", { - {"input", output}, {"axis", &axisData} - }, {}, {}); - output->Reshape({seqLens[b], 1, -1}); - excutor.Run("PermuteSelf", { - {"input", output}, {"axis", &axisData} - }, {}, {}); - } + for (int b = 0; b < batch; b++) { + excutor.Run("PermuteSelf", { + {"input", (Data*)&curQs[b]}, {"axis", &axisData} + }, {}, {}); + curQs[b].Reshape({-1, curQs[b].dims[2], curQs[b].dims[3]}); - auto lastOutput = allDatas[op.datas.find("output")->second]; - for (int b = 0; b < batch; b++) { - Data *output = (Data*)&curOutputs[b]; - if (b == 0) { - lastOutput->dataType = output->dataType; - std::vector dims = output->dims; - dims[1] = 0; - lastOutput->Resize(dims); - dims[1] = total; - lastOutput->Expansion(dims); + excutor.Run("PermuteSelf", { + {"input", (Data*)&curKs[b]}, {"axis", &axisData} + }, {}, {}); + curKs[b].Reshape({-1, curKs[b].dims[2], curKs[b].dims[3]}); + + excutor.Run("PermuteSelf", { + {"input", (Data*)&curVs[b]}, {"axis", &axisData} + }, {}, {}); + curVs[b].Reshape({-1, curVs[b].dims[2], curVs[b].dims[3]}); + } + + int unitLen = op.intParams.find("unitLen")->second; + for (int b = 0; b < batch; b++) { + for (int i = 0; i < 2; i++) { + auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "_" + std::to_string(b)]; + auto cur = i == 0 ? &curKs[b] : &curVs[b]; + while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1])) + || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) { + std::vector newDims; + if (cache->Count(0) == 0 || cache->dims.size() == 0) { + newDims = std::vector {cur->dims[0], ((cur->dims[1] - 1) / unitLen + 1) * unitLen, cur->dims[2]}; + } else { + newDims = cache->dims; + newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen; + } + cache->Expansion(newDims); + } + excutor.Run("CatDirect", { + {"input0", cache}, {"input1", cur} + }, {}, {{"axis", 1}}); + } + } + + for (int b = 0; b < batch; b++) { + std::string sb = "_" + std::to_string(b); + Data *k = allDatas[op.datas.find("k")->second + sb]; + Data *v = allDatas[op.datas.find("v")->second + sb]; + Data *mask = allDatas[op.datas.find("mask")->second + sb]; + excutor.Run("Attention", { + {"q", (Data*)&curQs[b]}, {"k", k}, {"v", v}, + {"mask", mask}, {"output", (Data*)&curOutputs[b]} + }, {{"scale", op.floatParams.find("scale")->second}}, + {{"maskType", 0}}); + } + + for (int b = 0; b < batch; b++) { + std::vector axis = {1, 0, 2}; + Data axisData = Data(DataType::INT32PARAM, {(int)axis.size()}); + axisData.Allocate(); + for (int i = 0; i < axisData.Count(0); i++) { + ((int32_t*)axisData.cpuData)[i] = axis[i]; + } + Data *output = (Data*)&curOutputs[b]; + excutor.Run("PermuteSelf", { + {"input", output}, {"axis", &axisData} + }, {}, {}); + output->Reshape({seqLens[b], 1, -1}); + excutor.Run("PermuteSelf", { + {"input", output}, {"axis", &axisData} + }, {}, {}); + } + + auto lastOutput = allDatas[op.datas.find("output")->second]; + for (int b = 0; b < batch; b++) { + Data *output = (Data*)&curOutputs[b]; + if (b == 0) { + lastOutput->dataType = output->dataType; + std::vector dims = output->dims; + dims[1] = 0; + lastOutput->Resize(dims); + dims[1] = total; + lastOutput->Expansion(dims); + } + excutor.Run("CatDirect", { + {"input0", lastOutput}, {"input1", output} + }, {}, {{"axis", 1}}); } - excutor.Run("CatDirect", { - {"input0", lastOutput}, {"input1", output} - }, {}, {{"axis", 1}}); } } } else if (op.type == "SplitLastTokenStates") {