add dymf (PaddlePaddle#10)
* dymf tmp

* add dymf tmp

* local test change

* pull thread pool

* fix conflict

* delete unused log

* local change for mirror 0

* fix dymf

* code clean

* fix code clean

* code clean

* code clean

* fix dymf

* fix dymf

* add endpass optimize

* clean code

* fix endpass optimize

* fix

* fix

Co-authored-by: yaoxuefeng6 <yaoxuefeng@baidu.com>
Co-authored-by: Thunderbrook <a754913769@163.com>
3 people committed Jun 7, 2022
1 parent b455a79 commit 022b54e
Showing 28 changed files with 1,381 additions and 169 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/framework/data_feed.cc
@@ -2436,7 +2436,6 @@ void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec,
if (feed == nullptr) {
continue;
}

auto& slot_offset = offset_[j];
slot_offset.clear();
slot_offset.reserve(num + 1);
33 changes: 33 additions & 0 deletions paddle/fluid/framework/data_set.cc
@@ -120,6 +120,22 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
&data_feed_desc_);
}

template <typename T>
std::vector<std::string> DatasetImpl<T>::GetSlots() {
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
use_slots_.clear();
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
use_slots_.push_back(slot.name());
}
std::cout << "dataset use slots: ";
for (const auto& s : use_slots_) {
std::cout << s << " | ";
}
std::cout << " end " << std::endl;
return use_slots_;
}

template <typename T>
void DatasetImpl<T>::SetChannelNum(int channel_num) {
channel_num_ = channel_num;
@@ -1773,5 +1789,22 @@ void SlotRecordDataset::DynamicAdjustReadersNum(int thread_num) {
PrepareTrain();
}

std::vector<std::string> SlotRecordDataset::GetSlots() {
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
use_slots_.clear();
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.type() == "uint64" || slot.type() == "uint32") {
use_slots_.push_back(slot.name());
}
}
std::cout << "dataset use slots: ";
for (const auto& s : use_slots_) {
std::cout << s << " | ";
}
std::cout << " end " << std::endl;
return use_slots_;
}

} // end namespace framework
} // end namespace paddle
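For context, a minimal caller-side sketch of the new API (the helper below is hypothetical, not part of this commit). GetSlots() re-reads the slot names from the already-configured data_feed_desc_, so it can be called any time after SetDataFeedDesc(); the SlotRecordDataset override keeps only the sparse ("uint64"/"uint32") slots.

#include <iostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_set.h"

// Hypothetical helper: print the slots a configured dataset exposes.
void DumpSlots(paddle::framework::Dataset* dataset) {
  for (const std::string& name : dataset->GetSlots()) {
    std::cout << "slot: " << name << "\n";
  }
}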
5 changes: 5 additions & 0 deletions paddle/fluid/framework/data_set.h
@@ -159,6 +159,8 @@ class Dataset {
// set fleet send sleep seconds
virtual void SetFleetSendSleepSeconds(int seconds) = 0;

virtual std::vector<std::string> GetSlots() = 0;

protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
@@ -246,6 +248,7 @@ class DatasetImpl : public Dataset {
bool discard_remaining_ins = false);
virtual void DynamicAdjustReadersNum(int thread_num);
virtual void SetFleetSendSleepSeconds(int seconds);
virtual std::vector<std::string> GetSlots();
/* for enable_heterps_
virtual void EnableHeterps(bool enable_heterps) {
enable_heterps_ = enable_heterps;
@@ -321,6 +324,7 @@ class DatasetImpl : public Dataset {
int64_t global_index_ = 0;
std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
std::vector<T> input_records_; // only for paddleboxdatafeed
std::vector<std::string> use_slots_;
bool enable_heterps_ = false;
};

@@ -379,6 +383,7 @@ class SlotRecordDataset : public DatasetImpl<SlotRecord> {
bool discard_remaining_ins);
virtual void PrepareTrain();
virtual void DynamicAdjustReadersNum(int thread_num);
virtual std::vector<std::string> GetSlots();

protected:
bool enable_heterps_ = true;
8 changes: 4 additions & 4 deletions paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -69,7 +69,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
int node_num, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
VLOG(3) << "Going to init worker";
VLOG(0) << "Going to init worker";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_worker(dist_desc,
@@ -126,7 +126,7 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,

void FleetWrapper::GatherClients(const std::vector<uint64_t>& host_sign_list) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to gather client ips";
VLOG(0) << "Going to gather client ips";
size_t len = host_sign_list.size();
pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()), len);
#endif
@@ -142,7 +142,7 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() {

void FleetWrapper::CreateClient2ClientConnection() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to create client2client connection";
VLOG(0) << "Going to create client2client connection";
pslib_ptr_->create_client2client_connection(client2client_request_timeout_ms_,
client2client_connect_timeout_ms_,
client2client_max_retry_);
@@ -1054,7 +1054,7 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
int slot_offset = 0;
int grad_dim = 0;
// don't worry, users do not have to care about all these flags
if (accesor == "DownpourCtrAccessor") {
if (accesor == "DownpourCtrAccessor" || accesor == "DownpourCtrDymfAccessor") {
dump_slot = true;
slot_offset = 1;
grad_dim = fea_dim - 2;
5 changes: 0 additions & 5 deletions paddle/fluid/framework/fleet/heter_context.h
@@ -129,11 +129,6 @@ class HeterContext {
for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
feature_dim_keys_[i].resize(dim_num);
value_dim_ptr_[i].resize(dim_num);
if (i == 0) {
for (int j = 0; j < dim_num; j++) {
feature_dim_keys_[i][j].push_back(0);
}
}
}
device_values_.resize(device_num);
device_dim_values_.resize(device_num);
73 changes: 73 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
@@ -24,6 +24,7 @@ namespace framework {

typedef uint64_t FeatureKey;

/*
struct FeatureValue {
float delta_score;
float show;
@@ -65,6 +66,78 @@ struct FeaturePushValue {
return out;
}
};
*/

struct FeatureValue {
float delta_score;
float show;
float clk;
int slot;
float lr;
float lr_g2sum;
int mf_size;
int mf_dim;
uint64_t cpu_ptr;
float mf[0];

friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
<< " lr: " << val.lr << " mf_dim: " << val.mf_dim << "cpuptr: " << val.cpu_ptr
<< " mf_size: " << val.mf_size << " mf:";
for (int i = 0; i < val.mf_dim + 1; ++i) {
out << " " << val.mf[i];
}
return out;
}
__device__ __forceinline__ void operator=(const FeatureValue& in) {
delta_score = in.delta_score;
show = in.show;
clk = in.clk;
slot = in.slot;
lr = in.lr;
lr_g2sum = in.lr_g2sum;
mf_size = in.mf_size;
mf_dim = in.mf_dim;
cpu_ptr = in.cpu_ptr;
for (int i = 0; i < mf_dim + 1; i++) {
mf[i] = in.mf[i];
}
}
};

struct FeaturePushValue {
float show;
float clk;
int slot;
float lr_g;
int mf_dim;
float mf_g[0];

__device__ __forceinline__ FeaturePushValue
operator+(const FeaturePushValue& a) const {
FeaturePushValue out;
out.slot = a.slot;
out.mf_dim = a.mf_dim;
out.show = a.show + show;
out.clk = a.clk + clk;
out.lr_g = a.lr_g + lr_g;
// out.mf_g = a.mf_g;
for (int i = 0; i < out.mf_dim; ++i) {
out.mf_g[i] = a.mf_g[i] + mf_g[i];
}
return out;
}
__device__ __forceinline__ void operator=(const FeaturePushValue& in) {
show = in.show;
clk = in.clk;
slot = in.slot;
lr_g = in.lr_g;
mf_dim = in.mf_dim;
for (int i = 0; i < mf_dim; i++) {
mf_g[i] = in.mf_g[i];
}
}
};

} // end namespace framework
} // end namespace paddle
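A layout note on the new value types (a sketch under stated assumptions, not code from this commit): both structs now end in a zero-length array (float mf[0], a GNU extension), so a value's true size depends on its embedding dimension, and values must be packed into a raw char* pool and addressed by byte offset, never by sizeof(FeatureValue) alone. Consistent with the mf_dim + 1 copy loops above and the feature_value_size parameter threaded through the hashtable below, the sizing and addressing would look like:

#include <cstddef>

// Assumed sizing rule: mf_dim + 1 trailing floats (the extra float is
// presumably optimizer state such as mf_g2sum), matching the copy loop
// in FeatureValue::operator=. sizeof(FeatureValue) excludes mf[0].
size_t FeatureValueBytes(int mf_dim) {
  return sizeof(FeatureValue) + sizeof(float) * (mf_dim + 1);
}

// Values are fixed-stride slots in a byte pool, so slot i lives at a
// byte offset rather than at pool + i * sizeof(FeatureValue).
FeatureValue* FeatureValueAt(char* pool, size_t i, size_t value_bytes) {
  return reinterpret_cast<FeatureValue*>(pool + i * value_bytes);
}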
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/hashtable.h
Original file line number Diff line number Diff line change
@@ -56,8 +56,8 @@ class HashTable {
HashTable& operator=(const HashTable&) = delete;
void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
gpuStream_t stream);
void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
gpuStream_t stream);
void insert(const KeyType* d_keys, size_t len, char* pool, size_t feature_value_size,
size_t start_index, gpuStream_t stream);
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
gpuStream_t stream);
void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream);
49 changes: 41 additions & 8 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
@@ -46,15 +46,17 @@ __global__ void insert_kernel(Table* table,
template <typename Table>
__global__ void insert_kernel(Table* table,
const typename Table::key_type* const keys,
size_t len, char* pool, int start_index) {
size_t len, char* pool, size_t feature_value_size,
int start_index) {
ReplaceOp<typename Table::mapped_type> op;
thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

if (i < len) {
kv.first = keys[i];
kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
uint64_t offset = uint64_t(start_index + i) * feature_value_size;
kv.second = (Table::mapped_type)(pool + offset);
auto it = table->insert(kv, op);
assert(it != table->end() && "error: insert fails: table is full");
}
@@ -77,14 +79,43 @@ __global__ void search_kernel(Table* table,
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
const typename Table::key_type* const keys,
char* const vals, size_t len,
char* vals, size_t len,
size_t pull_feature_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);

if (it != table->end()) {
*(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
uint64_t offset = i * pull_feature_value_size;
FeatureValue* cur = (FeatureValue*)(vals + offset);
FeatureValue& input = *(FeatureValue*)(it->second);
cur->slot = input.slot;
cur->show = input.show;
cur->clk = input.clk;
cur->mf_dim = input.mf_dim;
cur->lr = input.lr;
cur->mf_size = input.mf_size;
cur->cpu_ptr = input.cpu_ptr;
cur->delta_score = input.delta_score;
cur->lr_g2sum = input.lr_g2sum;
for (int j = 0; j < cur->mf_dim + 1; ++j) {
cur->mf[j] = input.mf[j];
}
} else {
if (keys[i] != 0) printf("pull miss key: %d",keys[i]);
FeatureValue* cur = (FeatureValue*)(vals + i * pull_feature_value_size);
cur->delta_score = 0;
cur->show = 0;
cur->clk = 0;
cur->slot = -1;
cur->lr = 0;
cur->lr_g2sum = 0;
cur->mf_size = 0;
cur->mf_dim = 8;
cur->cpu_ptr = 0;  // no backing CPU value for a missed key
for (int j = 0; j < cur->mf_dim + 1; j++) {
cur->mf[j] = 0;
}

}
}
}
@@ -191,7 +222,7 @@ __global__ void dy_mf_update_kernel(Table* table,
FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
sgd.dy_mf_update_value((it.getter())->second, *cur);
} else {
printf("yxf::push miss key: %d", keys[i]);
if(keys[i] != 0) printf("push miss key: %d", keys[i]);
}
}
}
@@ -248,7 +279,9 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
char* pool, size_t start_index,
char* pool,
size_t feature_value_size,
size_t start_index,
gpuStream_t stream) {
if (len == 0) {
return;
@@ -258,7 +291,7 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
return;
}
insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
pool, start_index);
pool, feature_value_size, start_index);
}

template <typename KeyType, typename ValType>
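A hedged call-site sketch for the widened insert() (d_pool, d_keys, max_mf_dim, and the sizing rule are assumptions, not shown in this diff): the new feature_value_size argument lets insert_kernel stride through a variable-size value pool, at uint64_t(start_index + i) * feature_value_size, instead of the previous hardcoded 80 bytes per value, so the caller must pass the same per-value byte size it used when packing the pool.

// Hypothetical call site: insert `len` keys whose values were packed into a
// device pool `d_pool` at a fixed byte stride of `feature_value_size`.
size_t feature_value_size =
    sizeof(FeatureValue) + sizeof(float) * (max_mf_dim + 1);  // assumed sizing
table->insert(d_keys, len, d_pool, feature_value_size,
              /*start_index=*/0, stream);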