format general accessor (PaddlePaddle#74)
* format general accessor;test=develop

* format general accessor;test=develop
danleifeng committed Jul 28, 2022
1 parent bba158c commit eed5f6b
Showing 15 changed files with 418 additions and 392 deletions.
22 changes: 11 additions & 11 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.cu
@@ -22,7 +22,7 @@ const int CUDA_NUM_THREADS = platform::PADDLE_CUDA_NUM_THREADS;
 #define GET_BLOCK(N) ((N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS)
 #define CUDA_BLOCK(N) GET_BLOCK(N), CUDA_NUM_THREADS, 0
 
-template <typename FVAccessor>
+template <typename GPUAccessor>
 __global__ void PullCopy(float** dest,
                          const float* src,
                          const int64_t* len,
@@ -31,7 +31,7 @@ __global__ void PullCopy(float** dest,
                          uint64_t** keys,
                          uint64_t max_val_size,
                          int* gpu_dim,
-                         FVAccessor feature_value_accessor) {
+                         GPUAccessor gpu_accessor) {
   CUDA_KERNEL_LOOP(i, total_len) {
     int low = 0;
     int high = slot_num - 1;
@@ -47,7 +47,7 @@ __global__ void PullCopy(float** dest,
     float* feature_value_ptr =
         (float*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
     int mf_dim = gpu_dim[x] - 3;
-    feature_value_accessor.Select(
+    gpu_accessor.Select(
         dest[x] + y * (mf_dim + 3), feature_value_ptr, keys[x] + y, mf_dim);
   }
 }
@@ -98,7 +98,7 @@ __global__ void PullDedupCopy(
   }
 }
 
-template <typename FVAccessor>
+template <typename GPUAccessor>
 __global__ void PushCopyWithPool(float* dest,
                                  float** src,
                                  int64_t* len,
@@ -108,7 +108,7 @@ __global__ void PushCopyWithPool(float* dest,
                                  int* slot_vector,
                                  int* mf_dim_vector,
                                  size_t grad_value_size,
-                                 FVAccessor feature_value_accessor) {
+                                 GPUAccessor gpu_accessor) {
   CUDA_KERNEL_LOOP(i, total_len) {
     int low = 0;
     int high = slot_num - 1;
@@ -123,19 +123,19 @@ __global__ void PushCopyWithPool(float* dest,
     int y = i - (x ? len[low - 1] : 0);
     float* cur = (float*)((char*)dest + i * grad_value_size);
 
-    cur[feature_value_accessor.common_push_value.SlotIndex()] =
+    cur[gpu_accessor.common_push_value.SlotIndex()] =
         (float)slot_vector[x];
     int mf_dim = mf_dim_vector[x];
-    cur[feature_value_accessor.common_push_value.MfDimIndex()] = mf_dim;
+    cur[gpu_accessor.common_push_value.MfDimIndex()] = mf_dim;
 
-    cur[feature_value_accessor.common_push_value.ShowIndex()] =
+    cur[gpu_accessor.common_push_value.ShowIndex()] =
         *(src[x] + y * (mf_dim + 3));
-    cur[feature_value_accessor.common_push_value.ClickIndex()] =
+    cur[gpu_accessor.common_push_value.ClickIndex()] =
         *(src[x] + y * (mf_dim + 3) + 1);
-    cur[feature_value_accessor.common_push_value.EmbedGIndex()] =
+    cur[gpu_accessor.common_push_value.EmbedGIndex()] =
         *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
     for (int j = 0; j < mf_dim; j++) {
-      cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] =
+      cur[gpu_accessor.common_push_value.EmbedxGIndex() + j] =
          *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
     }
   }
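The pattern in this file — passing the accessor into a __global__ kernel by value as a template parameter — is the core of the rename: device code cannot call through host-side vtables, so the accessor must be a trivially copyable object whose index methods resolve at compile time. A minimal self-contained sketch of that idea (ToyAccessor and FillSlot are illustrative names, not Paddle code):

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical accessor exposing a compile-time-resolved field offset,
// standing in for CommonFeatureValueAccessor::common_push_value.
struct ToyAccessor {
  __host__ __device__ int SlotIndex() const { return 0; }
};

// Like PullCopy/PushCopyWithPool, the kernel is templated on the accessor
// and receives it by value; no virtual dispatch happens on the device.
template <typename GPUAccessor>
__global__ void FillSlot(float* dst, int slot, GPUAccessor acc) {
  dst[acc.SlotIndex()] = static_cast<float>(slot);
}

int main() {
  float* d_val;
  cudaMalloc(&d_val, sizeof(float));
  FillSlot<<<1, 1>>>(d_val, 7, ToyAccessor{});
  float h_val = 0.f;
  cudaMemcpy(&h_val, d_val, sizeof(float), cudaMemcpyDeviceToHost);
  printf("slot = %.0f\n", h_val);  // prints "slot = 7"
  cudaFree(d_val);
  return 0;
}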
40 changes: 17 additions & 23 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
@@ -36,27 +36,10 @@ typedef uint64_t FeatureKey;
 #define TYPEALIGN(ALIGNVAL, LEN) \
   (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
 
-class FeatureValueAccessor {
- public:
-  __host__ __device__ FeatureValueAccessor() {}
-  __host__ __device__ ~FeatureValueAccessor() {}
-
-  __host__ __device__ virtual int Configure(
-      std::unordered_map<std::string, float> config) {
-    _config = config;
-    Initialize();
-    return 0;
-  }
-  __host__ __device__ virtual int Initialize() = 0;
-
- protected:
-  std::unordered_map<std::string, float> _config;
-};
-
 // adagrad: embed_sgd_dim=1, embedx_sgd_dim=1,embedx_dim=n
 // adam std: embed_sgd_dim=4, embedx_sgd_dim=n*2+2,embedx_dim=n
 // adam shared: embed_sgd_dim=4, embedx_sgd_dim=4,embedx_dim=n
-class CommonFeatureValueAccessor : public FeatureValueAccessor {
+class CommonFeatureValueAccessor {
  public:
   struct CommonFeatureValue {
     /*
@@ -256,7 +239,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
   __host__ __device__ CommonFeatureValueAccessor() {}
   __host__ __device__ ~CommonFeatureValueAccessor() {}
 
-  __host__ __device__ virtual int Initialize() {
+  __host__ int Initialize() {
     int optimizer_type = (_config.find("optimizer_type") == _config.end())
                              ? 1
                              : int(_config["optimizer_type"]);
@@ -279,6 +262,12 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
     return 0;
   }
 
+  __host__ int Configure(std::unordered_map<std::string, float>& config) {
+    _config = config;
+    Initialize();
+    return 0;
+  }
+
   // copy cpu_val into gpu_val during the build stage
   __host__ void BuildFill(
       float* gpu_val, void* cpu,
@@ -561,6 +550,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
   }
 
  public:
+  std::unordered_map<std::string, float> _config;
   CommonFeatureValue common_feature_value;
   CommonPushValue common_push_value;
   CommonPullValue common_pull_value;
@@ -728,6 +718,10 @@ class AccessorWrapper : public VirtualAccessor {
     return gpu_accessor_.common_pull_value.Size(mf_dim);
   }
 
+  GPUAccessor* AccessorPtr() {
+    return &gpu_accessor_;
+  }
+
   virtual void BuildFill(void* gpu_val, void* cpu_val,
                          paddle::distributed::ValueAccessor* cpu_table_accessor,
                          int mf_dim) {
@@ -859,10 +853,10 @@ class AccessorWrapper : public VirtualAccessor {
   GPUAccessor gpu_accessor_;
 };
 
-class GlobalAccessorTransfor {
+class GlobalAccessorFactory {
  public:
-  static GlobalAccessorTransfor& GetInstance() {
-    static GlobalAccessorTransfor ins;
+  static GlobalAccessorFactory& GetInstance() {
+    static GlobalAccessorFactory ins;
     return ins;
   }
   void Init(std::string accessor_type) {
@@ -872,7 +866,7 @@ class GlobalAccessorFactory {
     if (accessor_type == "CtrDymfAccessor") {
       accessor_wrapper_ptr_ = new AccessorWrapper<CommonFeatureValueAccessor>();
     } else {
-      VLOG(0) << "GlobalAccessorTransfor Init not support accessor_type:"
+      VLOG(0) << "GlobalAccessorFactory Init not support accessor_type:"
              << accessor_type;
       accessor_wrapper_ptr_ = new AccessorWrapper<CommonFeatureValueAccessor>();
     }
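Dropping the abstract FeatureValueAccessor base (and its virtual Configure/Initialize pair) is consistent with the kernel-template change above: the accessor no longer needs a vtable, and setup moves into plain __host__ methods, while the renamed GlobalAccessorFactory keeps type erasure on the host side only, behind VirtualAccessor/AccessorWrapper. A rough sketch of that split, with all names other than the factory/wrapper pattern itself being illustrative:

#include <iostream>
#include <string>
#include <unordered_map>

struct ToyAccessor {  // hypothetical stand-in for CommonFeatureValueAccessor
  std::unordered_map<std::string, float> _config;
  int Configure(std::unordered_map<std::string, float>& config) {
    _config = config;  // plain host method; no virtual dispatch required
    return 0;
  }
};

struct VirtualAccessorBase {  // host-side type erasure only
  virtual ~VirtualAccessorBase() {}
  virtual size_t GetFeatureValueSize(int mf_dim) = 0;
};

template <typename GPUAccessor>
struct Wrapper : public VirtualAccessorBase {
  GPUAccessor gpu_accessor_;
  GPUAccessor* AccessorPtr() { return &gpu_accessor_; }  // concrete handle for kernels
  size_t GetFeatureValueSize(int mf_dim) override {
    return (mf_dim + 3) * sizeof(float);  // toy layout: show/click/embed + mf
  }
};

class Factory {  // Meyers singleton, mirroring GlobalAccessorFactory
 public:
  static Factory& GetInstance() {
    static Factory ins;
    return ins;
  }
  void Init(const std::string& accessor_type) {
    if (wrapper_ != nullptr) return;               // already initialized
    wrapper_ = new Wrapper<ToyAccessor>();         // default for unknown types too
  }
  VirtualAccessorBase* GetAccessorWrapper() { return wrapper_; }

 private:
  VirtualAccessorBase* wrapper_ = nullptr;
};

int main() {
  Factory::GetInstance().Init("CtrDymfAccessor");
  std::cout << Factory::GetInstance().GetAccessorWrapper()->GetFeatureValueSize(8)
            << " bytes\n";
  return 0;
}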
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/hashtable.h
@@ -126,9 +126,9 @@ class HashTable {
   void get(const KeyType* d_keys, ValType* d_vals, size_t len,
            StreamType stream);
 
-  template <typename StreamType, typename FVAccessor>
+  template <typename StreamType, typename GPUAccessor>
   void get(const KeyType* d_keys, char* d_vals, size_t len, StreamType stream,
-           FVAccessor& fv_accessor);
+           GPUAccessor& fv_accessor);
 
   void show();
 
22 changes: 11 additions & 11 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -80,12 +80,12 @@ __global__ void search_kernel(Table* table,
   }
 }
 
-template <typename Table, typename FVAccessor>
+template <typename Table, typename GPUAccessor>
 __global__ void dy_mf_search_kernel(Table* table,
                                     const typename Table::key_type* const keys,
                                     char* vals, size_t len,
                                     size_t pull_feature_value_size,
-                                    FVAccessor feature_value_accessor) {
+                                    GPUAccessor gpu_accessor) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   // return;
   if (i < len) {
@@ -94,7 +94,7 @@ __global__ void dy_mf_search_kernel(Table* table,
       uint64_t offset = i * pull_feature_value_size;
       float* cur = (float*)(vals + offset);
       float* input = it->second;
-      feature_value_accessor.PullValueFill(cur, input);
+      gpu_accessor.PullValueFill(cur, input);
     }
   }
 }
@@ -180,10 +180,10 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
 }
 
 template <typename KeyType, typename ValType>
-template <typename StreamType, typename FVAccessor>
+template <typename StreamType, typename GPUAccessor>
 void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                       size_t len, StreamType stream,
-                                      FVAccessor& fv_accessor) {
+                                      GPUAccessor& fv_accessor) {
   if (len == 0) {
     return;
   }
@@ -390,17 +390,17 @@ template void HashTable<unsigned long, float*>::dump_to_cpu<cudaStream_t>(
     int devid, cudaStream_t stream);
 
 template void
-HashTable<unsigned long, float*>::update<SparseAdagradOptimizer, cudaStream_t>(
+HashTable<unsigned long, float*>::update<SparseAdagradOptimizer<CommonFeatureValueAccessor>, cudaStream_t>(
     const unsigned long* d_keys, const char* d_grads, size_t len,
-    SparseAdagradOptimizer sgd, cudaStream_t stream);
+    SparseAdagradOptimizer<CommonFeatureValueAccessor> sgd, cudaStream_t stream);
 template void
-HashTable<unsigned long, float*>::update<SparseAdamOptimizer, cudaStream_t>(
+HashTable<unsigned long, float*>::update<SparseAdamOptimizer<CommonFeatureValueAccessor>, cudaStream_t>(
     const unsigned long* d_keys, const char* d_grads, size_t len,
-    SparseAdamOptimizer sgd, cudaStream_t stream);
+    SparseAdamOptimizer<CommonFeatureValueAccessor> sgd, cudaStream_t stream);
 template void HashTable<unsigned long, float*>::update<
-    SparseAdamSharedOptimizer, cudaStream_t>(const unsigned long* d_keys,
+    SparseAdamSharedOptimizer<CommonFeatureValueAccessor>, cudaStream_t>(const unsigned long* d_keys,
                                              const char* d_grads, size_t len,
-                                             SparseAdamSharedOptimizer sgd,
+                                             SparseAdamSharedOptimizer<CommonFeatureValueAccessor> sgd,
                                              cudaStream_t stream);
 
 // template void HashTable<unsigned long,
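The instantiation churn at the bottom of this file follows from the optimizers becoming class templates over the accessor: update() is a member function template defined in a .cu file, so every optimizer/stream combination used elsewhere must be explicitly instantiated with the full optimizer type, accessor argument included. A reduced illustration of why the accessor now appears in each instantiation (Table, AdagradLike, and ToyAccessor are made-up names):

#include <cstddef>
#include <cstdio>

struct ToyAccessor {};  // hypothetical accessor type

template <typename GPUAccessor>
struct AdagradLike {  // the optimizer is itself a template over the accessor
  void Update(unsigned long key) const { std::printf("update key %lu\n", key); }
};

template <typename KeyType>
struct Table {
  template <typename Sgd>
  void update(const KeyType* keys, size_t len, Sgd sgd) {
    for (size_t i = 0; i < len; ++i) sgd.Update(keys[i]);
  }
};

// With the definition living in a separate translation unit, each
// specialization used elsewhere needs an explicit instantiation naming
// the full optimizer type, accessor and all:
template void Table<unsigned long>::update<AdagradLike<ToyAccessor>>(
    const unsigned long* keys, size_t len, AdagradLike<ToyAccessor> sgd);

int main() {
  Table<unsigned long> table;
  unsigned long keys[2] = {11, 22};
  table.update(keys, 2, AdagradLike<ToyAccessor>{});
  return 0;
}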
10 changes: 4 additions & 6 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -46,10 +46,11 @@ namespace framework {
   (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
 
 template <typename KeyType, typename ValType, typename GradType,
-          typename FVAccessor>
+          typename GPUAccessor>
 class HeterComm {
  public:
-  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
+  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
+  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource, GPUAccessor& gpu_accessor);
   virtual ~HeterComm();
   HeterComm(const HeterComm&) = delete;
   HeterComm& operator=(const HeterComm&) = delete;
@@ -123,9 +124,6 @@ class HeterComm {
     max_mf_dim_ = max_mf_dim;
   }
 
-  void set_accessor(FVAccessor& accessor) {
-    feature_value_accessor_ = accessor;
-  }
 #endif
 
   bool need_transfer(int send_id, int receive_id) {
@@ -267,7 +265,7 @@ class HeterComm {
   int block_size_{256};
   std::unique_ptr<HeterCommKernel> heter_comm_kernel_;
 
-  FVAccessor feature_value_accessor_;
+  GPUAccessor gpu_accessor_;
 
  private:
   int topo_aware_{0};
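Here the commit swaps setter injection (the removed set_accessor) for constructor injection: the accessor is supplied when the comm object is built, so gpu_accessor_ is never observed half-initialized. A minimal sketch of the two-constructor shape, assuming placeholder types (ToyResource and ToyGPUAccessor are not Paddle names):

#include <cstdio>
#include <memory>

struct ToyResource {};     // stands in for HeterPsResource
struct ToyGPUAccessor {};  // stands in for a concrete GPU accessor

template <typename KeyType, typename ValType, typename GradType,
          typename GPUAccessor>
class ToyHeterComm {
 public:
  ToyHeterComm(size_t capacity, std::shared_ptr<ToyResource> resource) {}
  // New overload: the accessor is handed over at construction time and
  // copied once into the member, replacing the old set_accessor() call.
  ToyHeterComm(size_t capacity, std::shared_ptr<ToyResource> resource,
               GPUAccessor& gpu_accessor)
      : gpu_accessor_(gpu_accessor) {}

 private:
  GPUAccessor gpu_accessor_;  // renamed from feature_value_accessor_
};

int main() {
  auto resource = std::make_shared<ToyResource>();
  ToyGPUAccessor accessor;
  ToyHeterComm<unsigned long, float*, float*, ToyGPUAccessor> comm(
      1 << 20, resource, accessor);
  std::printf("constructed with accessor injected\n");
  return 0;
}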