diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 9cd733b3ff636..b1bdeb2612c80 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -203,7 +203,7 @@ class HeterComm { void init_path(); - void create_storage(int start_index, int end_index, int keylen, int vallen); + void create_storage(int start_index, int end_index, size_t keylen, size_t vallen); void destroy_storage(int start_index, int end_index); void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key, GradType* src_val); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index a0f58c7798acc..1292f74542bce 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -243,20 +243,20 @@ void HeterComm::init_path() { template void HeterComm::create_storage(int start_index, int end_index, - int keylen, - int vallen) { + size_t keylen, + size_t vallen) { auto& allocator = allocators_[start_index]; auto& nodes = path_[start_index][end_index].nodes_; for (size_t i = 0; i < nodes.size(); ++i) { platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num)); - allocator->DeviceAllocate( + PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceAllocate( resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].key_storage), // NOLINT - keylen, resource_->remote_stream(nodes[i].gpu_num, start_index)); - allocator->DeviceAllocate( + keylen, resource_->remote_stream(nodes[i].gpu_num, start_index))); + PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceAllocate( resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].val_storage), // NOLINT - vallen, resource_->remote_stream(nodes[i].gpu_num, start_index)); + vallen, resource_->remote_stream(nodes[i].gpu_num, start_index))); nodes[i].key_bytes_len = keylen; nodes[i].val_bytes_len = vallen; @@ -271,10 +271,10 @@ void HeterComm::destroy_storage(int start_index, for (size_t i = 0; i < nodes.size(); ++i) { platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num)); - allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), - nodes[i].key_storage); - allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), - nodes[i].val_storage); + PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), + nodes[i].key_storage)); + PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), + nodes[i].val_storage)); } }