Skip to content

Commit

Permalink
修复cudagraph&&优化CalcAucKernel&&修复fleet_last_base (#5)
Browse files Browse the repository at this point in the history
* fix IsThreadLocalCapturing

* run cuda kernel: CalcAucKernel with 512 threads

* fix_afs_api_download_dnn_plugin

* fix_fleet_last_base
  • Loading branch information
rensilin committed May 20, 2022
1 parent 0aaec55 commit b455a79
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 56 deletions.
12 changes: 6 additions & 6 deletions paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,20 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
stream, errors::PermissionDenied(
"CUDA Graph cannot be captured in default CUDA stream 0."));
capture_mode_ = mode;
if (mode == cudaStreamCaptureModeThreadLocal) {
capturing_thread_id_ = std::this_thread::get_id();
VLOG(10) << "Capturing CUDA Graph in thread local mode, thread id: "
<< capturing_thread_id_;
}
// While mode is cudaStreamCaptureModeThreadLocal, other threads may
// be running Alloc at the same time.
// To make sure IsThisThreadCapturing() returns false on those other threads,
// this thread_fence is necessary. That mean when IsCapturing() is true,
// this thread_fence is necessary. That means when IsCapturing() is true,
// the capture_mode_ it will get is always cudaStreamCaptureModeThreadLocal.
std::atomic_thread_fence(std::memory_order_release);
capturing_graph_.reset(new CUDAGraph());
capturing_graph_->place_ = place;
capturing_graph_->stream_ = stream;
if (mode == cudaStreamCaptureModeThreadLocal) {
capturing_thread_id_ = std::this_thread::get_id();
VLOG(10) << "Capturing CUDA Graph in thread local mode, thread id: "
<< capturing_thread_id_;
}
BeginSegmentCapture();
#endif
}
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ class CUDAGraph {

static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
capture_mode_ == cudaStreamCaptureModeThreadLocal;
#else
return false;
if (IsCapturing()) {
std::atomic_thread_fence(std::memory_order_acquire);
return capture_mode_ == cudaStreamCaptureModeThreadLocal;
}
#endif
return false;
}

static bool IsThisThreadCapturing() {
Expand Down
74 changes: 54 additions & 20 deletions paddle/phi/kernels/gpu/auc_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include <cub/cub.cuh>

namespace phi {

Expand Down Expand Up @@ -88,34 +89,67 @@ __global__ void AddDataKernel(const int64_t *label_data,
}
}

template <int BLOCKDIM>
__global__ void CalcAucKernel(int64_t *stat_pos,
int64_t *stat_neg,
int num_thresholds,
double *auc,
bool need_add_batch_num) {
*auc = 0.0f;
double totPos = 0.0;
double totNeg = 0.0;
double totPosPrev = 0.0;
double totNegPrev = 0.0;
typedef cub::BlockScan<int64_t, BLOCKDIM> Int64BlockScan;
__shared__ typename Int64BlockScan::TempStorage int64_scan_storage;

int idx = num_thresholds;
typedef cub::BlockReduce<int64_t, BLOCKDIM> Int64BlockReduce;
__shared__ typename Int64BlockReduce::TempStorage int64_reduce_storage;

while (idx >= 0) {
totPosPrev = totPos;
totNegPrev = totNeg;
totPos += stat_pos[idx];
totNeg += stat_neg[idx];
*auc += (totNeg - totNegPrev) * (totPos + totPosPrev) / 2.0;
--idx;
}
typedef cub::BlockReduce<double, BLOCKDIM> DoubleBlockReduce;
__shared__ typename DoubleBlockReduce::TempStorage double_reduce_storage;

int64_t total_pos_num_local = 0; // thread_local_num
int64_t total_neg_num = 0; // global_num

double area_local = 0.0;
int block_begin_idx = 0;
for (; block_begin_idx < num_thresholds; block_begin_idx += BLOCKDIM) {
int idx = block_begin_idx + threadIdx.x;
int64_t pos_num = 0;
int64_t neg_num = 0;
if (idx <= num_thresholds) {
pos_num = stat_pos[idx];
neg_num = stat_neg[idx];
}
total_pos_num_local += pos_num;

if (totPos > 0.0 && totNeg > 0.0) {
*auc = *auc / totPos / totNeg;
int64_t block_aggregate = 0;
int64_t neg_prefix_sum = 0;
__syncthreads();
Int64BlockScan(int64_scan_storage).ExclusiveSum(neg_num, neg_prefix_sum, block_aggregate);

neg_prefix_sum += total_neg_num;
total_neg_num += block_aggregate;
area_local += static_cast<double>(pos_num) * (neg_prefix_sum + neg_prefix_sum + neg_num);
}
if (need_add_batch_num) {
stat_pos[num_thresholds + 1] += 1;
stat_neg[num_thresholds + 1] += 1;

int64_t total_pos_num = Int64BlockReduce(int64_reduce_storage).Sum(total_pos_num_local);
double area = DoubleBlockReduce(double_reduce_storage).Sum(area_local);

if (threadIdx.x == 0) {
if (block_begin_idx == num_thresholds) {
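// When num_thresholds % BLOCKDIM == 0 the loop above exits before visiting
// the last bucket (index num_thresholds), so fold that bucket in here.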
int64_t pos_num = stat_pos[num_thresholds];
int64_t neg_num = stat_neg[num_thresholds];
area += static_cast<double>(pos_num) * (total_neg_num + total_neg_num + neg_num);
total_pos_num += pos_num;
total_neg_num += neg_num;
}
if (total_pos_num == 0 || total_neg_num == 0) {
*auc = 0.0;
} else {
*auc = area / total_pos_num / total_neg_num / 2.0;
}
if (need_add_batch_num) {
stat_pos[num_thresholds + 1] += 1;
stat_neg[num_thresholds + 1] += 1;
}
}
}

Expand Down Expand Up @@ -259,7 +293,7 @@ void AucKernel(const Context &dev_ctx,
origin_stat_pos,
origin_stat_neg);
int sum_offset = slide_steps * (num_thresholds + 1);
CalcAucKernel<<<1, 1, 0, dev_ctx.stream()>>>(origin_stat_pos + sum_offset,
CalcAucKernel<512><<<1, 512, 0, dev_ctx.stream()>>>(origin_stat_pos + sum_offset,
origin_stat_neg + sum_offset,
num_thresholds,
auc_value,
Expand Down
29 changes: 7 additions & 22 deletions python/paddle/distributed/fleet/utils/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,35 +1377,20 @@ def download(self, fs_path, local_path, multi_processes=1, overwrite=False):
client.download("hdfs:/test_hdfs_client", "./")
"""

def __subprocess_download(local_path, datas):
"""
download file from HDFS
Args:
local_path(str): the local file path
datas(str): the hdfs file path list
"""
for data in datas:
self._fs.download(local_path, data)

if not self.is_exist(fs_path):
raise FSFileNotExistsError("{} not exits".format(fs_path))
# download file
if self.is_file(fs_path):
return self._fs.download(local_path, fs_path)
if not os.path.exists(local_path):
os.mkdir(local_path)
elif not os.path.isdir(local_path):
raise FSFileNotExistsError("{} is not dir".format(local_path))
# download dir
_, all_filenames = self.ls_dir(fs_path)
all_files = [fs_path + i for i in all_filenames]
procs = []
for i in range(multi_processes):
process_datas = self._split_files(all_files, i, multi_processes)
p = multiprocessing.Process(
target=__subprocess_download, args=(local_path, process_datas))
procs.append(p)
p.start()

# complete the processes
for proc in procs:
proc.join()
for file_name in all_filenames:
local_file_name = os.path.join(local_path, os.path.split(file_name)[1])
self._fs.download(local_file_name, file_name)

def mkdirs(self, fs_path):
"""
Expand Down
5 changes: 4 additions & 1 deletion python/paddle/fluid/incubate/fleet/utils/fleet_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ def write_xbox_donefile(self,
last_dict = json.loads(pre_content.split("\n")[-1])
last_day = last_dict["input"].split("/")[-3]
last_pass = last_dict["input"].split("/")[-2].split("-")[-1]
if last_pass == 'base':
last_pass = '0'
exist = False
if int(day) < int(last_day) or \
int(day) == int(last_day) and \
Expand Down Expand Up @@ -2066,7 +2068,8 @@ def write_xbox_donefile(self,
last_dict = json.loads(pre_content.strip().split("\n")[-1])
last_day = last_dict["input"].split("/")[-3]
last_pass = last_dict["input"].split("/")[-2].split("-")[-1]

if last_pass == 'base':
last_pass = '0'
os.remove(donefile_name)
self.rank0_info("remove %s succeed" % (donefile_name))
exist = False
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/fluid/tests/unittests/test_auc_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def setUp(self):
self.op_type = "auc"
pred = np.random.random((128, 2)).astype("float32")
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
labels[8, 0] = -1
labels[18, 0] = 2
labels[64:] = -1
labels[100:] = 2
num_thresholds = 200
slide_steps = 1

Expand All @@ -96,7 +96,7 @@ def setUp(self):
python_auc = metrics.Auc(name="auc",
curve='ROC',
num_thresholds=num_thresholds)
python_auc.update(pred, labels)
python_auc.update(pred[:64], labels[:64])

pos = python_auc._stat_pos * 2
pos.append(1)
Expand Down

0 comments on commit b455a79

Please sign in to comment.