Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#60 from YaoCheng8667/paddlebox-yc
Browse files: browse the repository at this point in the history
modify interface for compress push
  • Loading branch information
YaoCheng8667 committed Mar 22, 2024
2 parents 93730f6 + 633c2d0 commit dd20513
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
18 changes: 14 additions & 4 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2808,26 +2808,31 @@ void PadBoxSlotDataset::PrepareTrain(void) {
->AddBatchOffset(offset[i]);
}
#ifdef PADDLE_WITH_XPU_KP
using BatchData = std::vector<std::pair<uint64_t*, int>>;
using BatchData = std::vector<std::vector<std::pair<uint64_t*, int>>>; // devid -> dev_batch_data

VLOG(0) << "PadBoxSlotDataset::PrepareTrain with pv_merge offset size:" << offset.size()
<< ", thread_num:" << thread_num_;
auto data_func = [this, offset] (int batch_idx, BatchData * out_data) {
BatchData & batch_data = *out_data;
batch_data.clear();
batch_data.resize(thread_num_);

int offset_idx = batch_idx * thread_num_;
CHECK(offset_idx + thread_num_ <= (int)offset.size())
<< "offset_idx:" << offset_idx
<< ", thread_num_:" << thread_num_
<< "offset.size:" << offset.size();
for (int j = 0; j < thread_num_; j++) {
int dev_id = j;
auto & dev_batch_data = batch_data[dev_id];
auto & offset_pair = offset[offset_idx + j];
for (int k = 0; k < offset_pair.second; k++) {
auto & pv_ins = input_pv_ins_[offset_pair.first + k]->ads;
size_t num = 0;
for (auto & rec : pv_ins) {
for (auto& idx : used_fea_index_) {
uint64_t* feas = rec->slot_uint64_feasigns_.get_values(idx, &num);
batch_data.push_back(std::make_pair(feas, num));
dev_batch_data.push_back(std::make_pair(feas, num));
}
}
}
Expand Down Expand Up @@ -2857,25 +2862,30 @@ void PadBoxSlotDataset::PrepareTrain(void) {
->AddBatchOffset(offset[i]);
}
#ifdef PADDLE_WITH_XPU_KP
using BatchData = std::vector<std::pair<uint64_t*, int>>;
using BatchData = std::vector<std::vector<std::pair<uint64_t*, int>>>; // devid -> dev_batch_data
VLOG(0) << "PadBoxSlotDataset::PrepareTrain offset size:" << offset.size()
<< ", thread_num:" << thread_num_;
auto data_func = [this, offset] (int batch_idx, BatchData * out_data) {
BatchData & batch_data = *out_data;
batch_data.clear();
batch_data.resize(thread_num_);

int offset_idx = batch_idx * thread_num_;
CHECK(offset_idx + thread_num_ <= (int)offset.size())
<< "offset_idx:" << offset_idx
<< ", thread_num_:" << thread_num_
<< "offset.size:" << offset.size();
for (int j = 0; j < thread_num_; j++) {
int dev_id = j;
auto & dev_batch_data = batch_data[dev_id];

auto & offset_pair = offset[offset_idx + j];
for (int k = 0; k < offset_pair.second; k++) {
auto & rec = input_records_[offset_pair.first + k];
size_t num = 0;
for (auto& idx : used_fea_index_) {
uint64_t* feas = rec->slot_uint64_feasigns_.get_values(idx, &num);
batch_data.push_back(std::make_pair(feas, num));
dev_batch_data.push_back(std::make_pair(feas, num));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/box_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ void BoxWrapper::GetFeatureOffsetInfo(void) {

#ifdef PADDLE_WITH_XPU_KP
// Thin forwarding wrapper: passes `batch_num` and `data_func` straight through
// to boxps_ptr_->SetDataFuncForCacheManager, together with the address of the
// member `fid2sign_map_`.
//
// `data_func` is a callback taking an int and a pointer to an output
// container of (uint64_t*, int) pairs.
// NOTE(review): presumably the int is a batch index and the nested-vector
// form is indexed by device id (one inner vector per device) — confirm
// against the callers that build the callback.
void BoxWrapper::SetDataFuncForCacheManager(int batch_num,
std::function<void(int, std::vector<std::pair<uint64_t*, int>>*)> data_func) {
std::function<void(int, std::vector<std::vector<std::pair<uint64_t*, int>>>*)> data_func) {
boxps_ptr_->SetDataFuncForCacheManager(batch_num, data_func, &fid2sign_map_);
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ class BoxWrapper {

#ifdef PADDLE_WITH_XPU_KP
// Declaration of the cache-manager data-callback setter (only built when
// PADDLE_WITH_XPU_KP is defined).
// `batch_num`: integer batch count argument.
// `data_func`: callback invoked with an int and a pointer to an output
// container of (uint64_t*, int) pairs.
// NOTE(review): the nested-vector form looks like a per-device grouping
// (outer index = device id) — confirm with the definition and its callers.
void SetDataFuncForCacheManager(int batch_num,
std::function<void(int, std::vector<std::pair<uint64_t*, int>>*)> data_func);
std::function<void(int, std::vector<std::vector<std::pair<uint64_t*, int>>>*)> data_func);
int PrepareNextBatch(int dev_id);
std::vector<uint64_t> * GetFid2SginMap() { return fid2sign_map_; }
#endif
Expand Down

0 comments on commit dd20513

Please sign in to comment.