diff --git a/doc/Doxyfile b/doc/Doxyfile
new file mode 100644
index 0000000..a940c1f
--- /dev/null
+++ b/doc/Doxyfile
@@ -0,0 +1,13 @@
+INPUT = ../src ../README.md
+FILE_PATTERNS = *.h *.cc *.cu *.cuh *.md
+EXTENSION_MAPPING = cu=C++
+EXTRACT_ALL = YES
+EXTRACT_ANON_NSPACES = YES
+EXCLUDE_PATTERNS = *.py *.R
+DOXYFILE_ENCODING = UTF-8
+PROJECT_NAME = KMeansCUDA
+OUTPUT_LANGUAGE = English
+GENERATE_XML = NO
+GENERATE_LATEX = NO
+GENERATE_HTML = YES
+HTML_OUTPUT = doxyhtml/
diff --git a/src/kmcuda.cc b/src/kmcuda.cc
index 7eac28f..eec8789 100644
--- a/src/kmcuda.cc
+++ b/src/kmcuda.cc
@@ -15,7 +15,7 @@
 
 #include "private.h"
 
-
+/// Used in kmeans_cuda() to validate the function arguments.
 static KMCUDAResult check_kmeans_args(
     float tolerance,
     float yinyang_t,
@@ -659,7 +659,7 @@ KMCUDAResult knn_cuda(
   cudaSetDevice(device_ptrs);
   CUCH(cudaMemcpy(
       asses_on_host.get(), assignments, samples_size * sizeof(uint32_t),
-      cudaMemcpyDeviceToHost), kmcudaRuntimeError);
+      cudaMemcpyDeviceToHost), kmcudaMemoryCopyError);
   #pragma omp parallel for
   for (uint32_t s = 0; s < samples_size; s++) {
     asses_with_idxs[s] = std::make_tuple(asses_on_host[s], s);
diff --git a/src/kmcuda.h b/src/kmcuda.h
index 4b1c2a3..8b013b9 100644
--- a/src/kmcuda.h
+++ b/src/kmcuda.h
@@ -1,26 +1,82 @@
 #ifndef KMCUDA_KMCUDA_H
 #define KMCUDA_KMCUDA_H
 
+/*! @mainpage KMeansCUDA documentation
+ *
+ * @section s1 Description
+ *
+ * K-means and K-nn on NVIDIA CUDA, designed for production usage and
+ * simplicity.
+ *
+ * K-means is based on ["Yinyang K-Means: A Drop-In Replacement
+ * of the Classic K-Means with Consistent Speedup"](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/ding15.pdf).
+ * While it introduces some overhead and many conditional clauses
+ * which are bad for CUDA, it still shows a 1.6-2x speedup over the Lloyd
+ * algorithm. K-nearest neighbors employ the same triangle inequality idea and
+ * require precalculated centroids and cluster assignments, similar to a
+ * flattened ball tree.
+ *
+ * Project: https://github.com/src-d/kmcuda
+ *
+ * README: @ref ignore_this_doxygen_anchor
+ *
+ * @section s2 C/C++ API
+ *
+ * kmcuda.h exports two functions: kmeans_cuda() and knn_cuda(). They are not
+ * thread safe.
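+ *
+ * A hedged sketch of a typical kmeans_cuda() call follows. The authoritative
+ * argument list is the kmeans_cuda() declaration further down in this header;
+ * the sizes and values here are purely illustrative:
+ * @code{.cpp}
+ * #include <kmcuda.h>
+ *
+ * // 10000 samples with 2 features each, clustered into 50 groups.
+ * float samples[10000 * 2];      // filled elsewhere
+ * float centroids[50 * 2];       // output
+ * uint32_t assignments[10000];   // output
+ * float average_distance;        // output
+ * KMCUDAResult result = kmeans_cuda(
+ *     kmcudaInitMethodPlusPlus, NULL,  // centroid initialization and its params
+ *     0.01,                            // tolerance: stop when <1% are reassigned
+ *     0.1,                             // activate Yinyang with a 0.1 threshold
+ *     kmcudaDistanceMetricL2,          // Euclidean distance
+ *     10000, 2, 50,                    // samples_size, features_size, clusters_size
+ *     0xDEADBEEF,                      // random generator seed
+ *     0,                               // use all available CUDA devices
+ *     -1,                              // samples are supplied from host
+ *     0,                               // not in float16x2 mode
+ *     1,                               // moderate verbosity
+ *     samples, centroids, assignments, &average_distance);
+ * if (result != kmcudaSuccess) {
+ *   // handle the error
+ * }
+ * @endcode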
+ *
+ * @section s3 Python 3 API
+ *
+ * The shared library exports kmeans_cuda() and knn_cuda() Python wrappers.
+ *
+ * @section s4 R API
+ *
+ * The shared library exports kmeans_cuda() and knn_cuda() R wrappers (.External).
+ *
+ */
+
 #include <stdint.h>
 
+/// All possible error codes in the public API.
 typedef enum {
+  /// Everything's all right.
   kmcudaSuccess = 0,
+  /// Arguments which were passed into a function failed the validation.
   kmcudaInvalidArguments,
+  /// The requested CUDA device does not exist.
   kmcudaNoSuchDevice,
+  /// Failed to allocate memory on the CUDA device. Too big a size? Switch off Yinyang?
   kmcudaMemoryAllocationFailure,
+  /// Something bad and unidentified happened on the CUDA side.
   kmcudaRuntimeError,
+  /// Failed to copy memory to/from the CUDA device.
   kmcudaMemoryCopyError
 } KMCUDAResult;
 
+/// Centroid initialization method.
 typedef enum {
+  /// Pick the initial centroids randomly.
   kmcudaInitMethodRandom = 0,
+  /// Use the kmeans++ initialization method. Theoretically proven to yield
+  /// better clustering than kmcudaInitMethodRandom. O(n * k) complexity.
+  /// https://en.wikipedia.org/wiki/K-means%2B%2B
   kmcudaInitMethodPlusPlus,
+  /// AFK-MC2 initialization method. Theoretically proven to yield
+  /// better clustering results than kmcudaInitMethodRandom; asymptotically
+  /// matches kmcudaInitMethodPlusPlus while being fast: O(n + k) complexity.
+  /// Use it when kmcudaInitMethodPlusPlus takes too long to finish.
+  /// http://olivierbachem.ch/files/afkmcmc-oral-pdf.pdf
   kmcudaInitMethodAFKMC2,
+  /// Take user-supplied centroids.
   kmcudaInitMethodImport
 } KMCUDAInitMethod;
 
+/// Specifies how to calculate the distance between each pair of dots.
 typedef enum {
+  /// Measure the distance between dots using the Euclidean distance.
   kmcudaDistanceMetricL2,
+  /// Measure the distance between dots using the angle between them.
+  /// @note This metric requires all the supplied data to be normalized by L2 to 1.
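+  /// A hedged illustration of that normalization on the host (the buffer
+  /// layout and the variable names are assumptions, not part of the API):
+  /// @code{.cpp}
+  /// for (uint32_t s = 0; s < samples_size; s++) {
+  ///   float *dot = samples + s * features_size;
+  ///   float norm = 0;
+  ///   for (uint16_t f = 0; f < features_size; f++) {
+  ///     norm += dot[f] * dot[f];
+  ///   }
+  ///   norm = sqrtf(norm);  // assumed to be > 0
+  ///   for (uint16_t f = 0; f < features_size; f++) {
+  ///     dot[f] /= norm;
+  ///   }
+  /// }
+  /// @endcode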
   kmcudaDistanceMetricCosine
 } KMCUDADistanceMetric;
@@ -108,6 +164,7 @@ KMCUDAResult knn_cuda(
 
 namespace {
 namespace kmcuda {
+/// Mapping from strings to KMCUDAInitMethod - useful for wrappers.
 const std::unordered_map<std::string, KMCUDAInitMethod> init_methods {
     {"kmeans++", kmcudaInitMethodPlusPlus},
     {"k-means++", kmcudaInitMethodPlusPlus},
     {"random", kmcudaInitMethodRandom}
 };
@@ -116,6 +173,7 @@ const std::unordered_map<std::string, KMCUDAInitMethod> init_methods {
     {"random", kmcudaInitMethodRandom}
 };
 
+/// Mapping from strings to KMCUDADistanceMetric - useful for wrappers.
 const std::unordered_map<std::string, KMCUDADistanceMetric> metrics {
     {"euclidean", kmcudaDistanceMetricL2},
     {"L2", kmcudaDistanceMetricL2},
     {"cos", kmcudaDistanceMetricCosine},
     {"cosine", kmcudaDistanceMetricCosine},
     {"angular", kmcudaDistanceMetricCosine}
 };
@@ -125,6 +183,7 @@ const std::unordered_map<std::string, KMCUDADistanceMetric> metrics {
     {"angular", kmcudaDistanceMetricCosine}
 };
 
+/// Mapping from KMCUDAResult to strings - useful for wrappers.
 const std::unordered_map<int, const char *> statuses {
     {kmcudaSuccess, "Success"},
     {kmcudaInvalidArguments, "InvalidArguments"},
diff --git a/src/private.h b/src/private.h
index 0ea04e2..6a93b81 100644
--- a/src/private.h
+++ b/src/private.h
@@ -12,12 +12,17 @@ typedef double atomic_float;
 typedef float atomic_float;
 #endif
 
+/// printf() under the INFO log level (0).
 #define INFO(...) do { if (verbosity > 0) { printf(__VA_ARGS__); } } while (false)
+/// printf() under the DEBUG log level (1).
 #define DEBUG(...) do { if (verbosity > 1) { printf(__VA_ARGS__); } } while (false)
+/// printf() under the TRACE log level (2).
 #define TRACE(...) do { if (verbosity > 2) { printf(__VA_ARGS__); } } while (false)
 
 #define CUERRSTR() cudaGetErrorString(cudaGetLastError())
 
+/// Checks the CUDA call for errors; in case of an error, logs it and returns.
+/// The "return" is what forces this to be a macro.
 #define CUCH(cuda_call, ret, ...) \
 do { \
   auto __res = cuda_call; \
@@ -29,6 +34,8 @@ do { \
   } \
 } while (false)
 
+/// Checks whether the call returns 0; if not, executes arbitrary code and returns.
+/// The "return" is what forces this to be a macro.
 #define RETERR(call, ...) \
 do { \
   auto __res = call; \
@@ -38,21 +45,26 @@ do { \
   } \
 } while (false)
 
+/// Executes arbitrary code for every CUDA device.
 #define FOR_EACH_DEV(...) do { for (int dev : devs) { \
   cudaSetDevice(dev); \
   __VA_ARGS__; \
 } } while(false)
 
+/// Executes arbitrary code for every CUDA device and brings the device index
+/// (devi) into the scope.
 #define FOR_EACH_DEVI(...) do { for (size_t devi = 0; devi < devs.size(); devi++) { \
   cudaSetDevice(devs[devi]); \
   __VA_ARGS__; \
 } } while(false)
 
+/// Invokes cudaDeviceSynchronize() on every CUDA device.
 #define SYNC_ALL_DEVS do { \
   if (devs.size() > 1) { \
     FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
   } } while (false)
 
+/// Copies memory from device to host asynchronously across all the CUDA devices.
 #define CUMEMCPY_D2H_ASYNC(dst, dst_stride, src, src_offset, size) do { \
   FOR_EACH_DEVI(CUCH(cudaMemcpyAsync( \
       dst + dst_stride * devi, (src)[devi].get() + src_offset, \
@@ -62,11 +74,13 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
       kmcudaMemoryCopyError)); \
 } while(false)
 
+/// Copies memory from device to host synchronously across all the CUDA devices.
 #define CUMEMCPY_D2H(dst, src, size) do { \
   CUMEMCPY_D2H_ASYNC(dst, 0, src, 0, size); \
   FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaMemoryCopyError)); \
 } while(false)
 
+/// Copies memory from host to device asynchronously across all the CUDA devices.
 #define CUMEMCPY_H2D_ASYNC(dst, dst_offset, src, size) do { \
   FOR_EACH_DEVI(CUCH(cudaMemcpyAsync( \
       (dst)[devi].get() + dst_offset, src, \
@@ -76,11 +90,13 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
       kmcudaMemoryCopyError)); \
 } while(false)
 
+/// Copies memory from host to device synchronously across all the CUDA devices.
 #define CUMEMCPY_H2D(dst, src, size) do { \
   CUMEMCPY_H2D_ASYNC(dst, 0, src, size); \
   FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaMemoryCopyError)); \
 } while(false)
 
+/// Copies memory from device to device asynchronously across all the CUDA devices.
 #define CUMEMCPY_D2D_ASYNC(dst, dst_offset, src, src_offset, size) do { \
   FOR_EACH_DEVI(CUCH(cudaMemcpyAsync( \
       (dst)[devi].get() + dst_offset, (src)[devi].get() + src_offset, \
@@ -90,11 +106,13 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
       kmcudaMemoryCopyError)); \
 } while(false)
 
+/// Copies memory from device to device synchronously across all the CUDA devices.
 #define CUMEMCPY_D2D(dst, dst_offset, src, src_offset, size) do { \
   CUMEMCPY_D2D_ASYNC(dst, dst_offset, src, src_offset, size); \
   FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaMemoryCopyError)); \
 } while(false)
 
+/// Allocates memory on a CUDA device and adds the created pointer to the list.
 #define CUMALLOC_ONEN(dest, size, name, dev) do { \
   void *__ptr; \
   size_t __size = (size) * \
@@ -108,14 +126,19 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
         reinterpret_cast<char *>(__ptr) + __size, __size); \
 } while(false)
 
+/// Shortcut for CUMALLOC_ONEN which infers the log name from dest.
 #define CUMALLOC_ONE(dest, size, dev) CUMALLOC_ONEN(dest, size, #dest, dev)
 
+/// Allocates memory on all CUDA devices.
 #define CUMALLOCN(dest, size, name) do { \
   FOR_EACH_DEV(CUMALLOC_ONEN(dest, size, name, dev)); \
 } while(false)
 
+/// Allocates memory on all CUDA devices. Does not require the log name; infers
+/// it from dest.
 #define CUMALLOC(dest, size) CUMALLOCN(dest, size, #dest)
 
+/// Invokes cudaMemsetAsync() on all CUDA devices.
 #define CUMEMSET_ASYNC(dst, val, size) do { \
   FOR_EACH_DEVI(CUCH(cudaMemsetAsync( \
       (dst)[devi].get(), val, \
@@ -123,11 +146,13 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
       kmcudaRuntimeError)); \
 } while(false)
 
+/// Invokes cudaMemset() on all CUDA devices.
 #define CUMEMSET(dst, val, size) do { \
   CUMEMSET_ASYNC(dst, val, size); \
   FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
 } while(false)
 
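+// A hedged illustration of how the macros above combine. It assumes the
+// calling scope defines devs and verbosity, as every macro here requires,
+// and all the names are made up:
+//
+//   udevptrs<float> d_samples;
+//   CUMALLOC(d_samples, samples_size * features_size);  // on every device
+//   CUMEMCPY_H2D(d_samples, host_samples, samples_size * features_size);
+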
+/// Executes the specified code on all devices except the given one (index devi).
 #define FOR_OTHER_DEVS(...) do { \
   for (size_t odevi = 0; odevi < devs.size(); odevi++) { \
     if (odevi == devi) { \
@@ -136,6 +161,7 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
     __VA_ARGS__; \
   } } while(false)
 
+/// Copies memory peer to peer (from one device to another).
 #define CUP2P(what, offset, size) do { \
   CUCH(cudaMemcpyPeerAsync( \
       (*what)[odevi].get() + offset, devs[odevi], (*what)[devi].get() + offset, \
@@ -145,6 +171,7 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
 } while(false)
 
 #if CUDA_ARCH >= 60
+/// Dispatches a single call site to the proper kernel template branch (metric, fp16x2).
 #define KERNEL_SWITCH(f, ...) do { switch (metric) { \
     case kmcudaDistanceMetricL2: \
       if (!fp16x2) { \
@@ -178,6 +205,7 @@ FOR_EACH_DEV(CUCH(cudaDeviceSynchronize(), kmcudaRuntimeError)); \
   } } while(false)
 #endif
 
+/// Host-side alternative to dupper().
 template <typename T>
 inline T upper(T size, T each) {
   T div = size / each;
@@ -189,7 +217,14 @@ inline T upper(T size, T each) {
 
 using plan_t = std::vector<std::tuple<uint32_t, uint32_t>>;
 
-/// @brief (offset, size) pairs.
+/// @brief Generates the split across the CUDA devices: (offset, size) pairs.
+/// It aligns every chunk at 512 bytes without breaking the elements.
+/// @param amount The total work size - the array size in elements.
+/// @param size_each The element size in bytes. Thus the total memory size is
+/// amount * size_each.
+/// @param devs The list with the device numbers.
+/// @return The list with the offset-size pairs. The measurement unit is the
+/// element size.
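+/// A hedged illustration (the exact rounding lives in the implementation
+/// below): amount = 1000, size_each = 4 and two devices give chunks aligned
+/// at 512 / 4 = 128 elements, e.g. [(0, 512), (512, 488)].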
 inline plan_t distribute(
     uint32_t amount, uint32_t size_each, const std::vector<int> &devs) {
   if (devs.size() == 0) {
@@ -225,6 +260,8 @@ inline plan_t distribute(
   return res;
 }
 
+/// Extracts the maximum split length from the device distribution plan.
+/// It calls distribute() and finds the maximum size.
 inline uint32_t max_distribute_length(
     uint32_t amount, uint32_t size_each, const std::vector<int> &devs) {
   auto plan = distribute(amount, size_each, devs);
@@ -238,8 +275,8 @@ inline uint32_t max_distribute_length(
   return max_length;
 }
 
-inline void print_plan(
-    const char *name, const std::vector<std::tuple<uint32_t, uint32_t>>& plan) {
+/// Dumps the device split distribution to stdout.
+inline void print_plan(const char *name, const plan_t& plan) {
   printf("%s: [", name);
   bool first = true;
   for (auto& p : plan) {
@@ -254,44 +291,54 @@ inline void print_plan(
 
 extern "C" {
 
+/// Copies a single sample within the same device. Defined in transpose.cu.
 KMCUDAResult cuda_copy_sample_t(
     uint32_t index, uint32_t offset, uint32_t samples_size,
     uint16_t features_size, const std::vector<int> &devs, int verbosity,
     const udevptrs<float> &samples, udevptrs<float> *dest);
 
+/// Copies a single sample from device to host. Defined in transpose.cu.
 KMCUDAResult cuda_extract_sample_t(
     uint32_t index, uint32_t samples_size, uint16_t features_size,
     int verbosity, const float *samples, float *dest);
 
+/// Transposes the samples matrix. Defined in transpose.cu.
 KMCUDAResult cuda_transpose(
     uint32_t samples_size, uint16_t features_size, bool forward,
     const std::vector<int> &devs, int verbosity, udevptrs<float> *samples);
 
+/// Invokes the kmeans++ kernel. Defined in kmeans.cu.
 KMCUDAResult kmeans_cuda_plus_plus(
     uint32_t samples_size, uint32_t features_size, uint32_t cc,
     KMCUDADistanceMetric metric, const std::vector<int> &devs, int fp16x2,
     int verbosity, const udevptrs<float> &samples, udevptrs<float> *centroids,
     udevptrs<float> *dists, float *host_dists, atomic_float *dists_sum);
 
+/// Invokes the afk-mc2 kernel "calc_q". Defined in kmeans.cu.
 KMCUDAResult kmeans_cuda_afkmc2_calc_q(
     uint32_t samples_size, uint32_t features_size, uint32_t firstc,
     KMCUDADistanceMetric metric, const std::vector<int> &devs, int fp16x2,
     int verbosity, const udevptrs<float> &samples, udevptrs<float> *d_q,
     float *h_q);
 
+/// Invokes the afk-mc2 kernel "random_step". Defined in kmeans.cu.
 KMCUDAResult kmeans_cuda_afkmc2_random_step(
     uint32_t k, uint32_t m, uint64_t seed, int verbosity, const float *q,
     uint32_t *d_choices, uint32_t *h_choices, float *d_samples, float *h_samples);
 
+/// Invokes the afk-mc2 kernel "min_dist". Defined in kmeans.cu.
 KMCUDAResult kmeans_cuda_afkmc2_min_dist(
     uint32_t k, uint32_t m, KMCUDADistanceMetric metric, int fp16x2,
     int32_t verbosity, const float *samples, const uint32_t *choices,
     const float *centroids, float *d_min_dists, float *h_min_dists);
 
+/// Initializes the CUDA environment, e.g. assigns values to symbols.
+/// Defined in kmeans.cu.
 KMCUDAResult kmeans_cuda_setup(
     uint32_t samples_size, uint16_t features_size, uint32_t clusters_size,
     uint32_t yy_groups_size, const std::vector<int> &devs, int32_t verbosity);
 
+/// Performs the centroid initialization. Defined in kmcuda.cc.
 KMCUDAResult kmeans_init_centroids(
     KMCUDAInitMethod method, const void *init_params, uint32_t samples_size,
     uint16_t features_size, uint32_t clusters_size, KMCUDADistanceMetric metric,
@@ -299,6 +346,8 @@ KMCUDAResult kmeans_init_centroids(
     int32_t verbosity, const float *host_centroids, const udevptrs<float> &samples,
     udevptrs<float> *dists, udevptrs<float> *aux, udevptrs<float> *centroids);
 
+/// The complementary implementation of kmeans_cuda() which requires nvcc.
+/// Defined in kmeans.cu.
 KMCUDAResult kmeans_cuda_yy(
     float tolerance, uint32_t yy_groups_size, uint32_t samples_size,
     uint32_t clusters_size, uint16_t features_size, KMCUDADistanceMetric metric,
@@ -309,6 +358,8 @@ KMCUDAResult kmeans_cuda_yy(
     udevptrs<float> *centroids_yy, udevptrs<float> *bounds_yy,
     udevptrs<float> *drifts_yy, udevptrs<uint32_t> *passed_yy);
 
+/// Calculates the average distance between the cluster members and the
+/// corresponding centroid. Defined in kmeans.cu.
 KMCUDAResult kmeans_cuda_calc_average_distance(
     uint32_t samples_size, uint16_t features_size,
     KMCUDADistanceMetric metric, const std::vector<int> &devs, int fp16x2,
@@ -316,10 +367,14 @@ KMCUDAResult kmeans_cuda_calc_average_distance(
     const udevptrs<float> &centroids, const udevptrs<uint32_t> &assignments,
     float *average_distance);
 
+/// Prepares the CUDA environment for the K-nn calculation, e.g., assigns
+/// values to symbols. Defined in knn.cu.
 KMCUDAResult knn_cuda_setup(
     uint32_t samples_size, uint16_t features_size, uint32_t clusters_size,
     const std::vector<int> &devs, int32_t verbosity);
 
+/// The complementary implementation of knn_cuda() which requires nvcc.
+/// Defined in knn.cu.
 KMCUDAResult knn_cuda_calc(
     uint16_t k, uint32_t h_samples_size, uint32_t h_clusters_size,
     uint16_t h_features_size, KMCUDADistanceMetric metric,
@@ -330,6 +385,9 @@ KMCUDAResult knn_cuda_calc(
     udevptrs<float> *sample_dists, udevptrs<float> *radiuses,
     udevptrs<uint32_t> *neighbors);
 
+/// Inspects the amount of available shared memory and decides on the
+/// performance-critical property of knn_cuda_calc() - which of the two
+/// variants to follow.
 int knn_cuda_neighbors_mem_multiplier(uint16_t k, int dev, int verbosity);
 
 }  // extern "C"
diff --git a/src/tricks.cuh b/src/tricks.cuh
index 79f20ec..24b5578 100644
--- a/src/tricks.cuh
+++ b/src/tricks.cuh
@@ -2,6 +2,9 @@
 
 #define warpSize 32
 
+/// Inline function which rounds the ratio between size and each up to the
+/// nearest integer (ceiling division).
+/// @tparam T Any integer type. Calling dupper() on floating point types is useless.
 template <typename T>
 __device__ __forceinline__ T dupper(T size, T each) {
   T div = size / each;
@@ -16,7 +19,8 @@ __device__ __forceinline__ T dmin(T a, T b) {
   return a <= b? a : b;
 }*/
 
-// https://devblogs.nvidia.com/parallelforall/cuda-pro-tip-optimized-filtering-warp-aggregated-atomics/
+/// Optimized warp-aggregated increment - a drop-in replacement for atomicInc().
+/// https://devblogs.nvidia.com/parallelforall/cuda-pro-tip-optimized-filtering-warp-aggregated-atomics/
 __device__ __forceinline__ uint32_t atomicAggInc(uint32_t *ctr) {
   int mask = __ballot(1);
   int leader = __ffs(mask) - 1;
@@ -28,7 +32,8 @@ __device__ __forceinline__ uint32_t atomicAggInc(uint32_t *ctr) {
   return res + __popc(mask & ((1 << (threadIdx.x % warpSize)) - 1));
 }
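+
+// A hypothetical usage sketch: warp-aggregated stream compaction. Every
+// thread which has a match reserves one output slot, costing a single atomic
+// operation per warp instead of one per thread:
+//
+//   if (matches(input[i])) {
+//     output[atomicAggInc(&output_count)] = input[i];
+//   }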
 
-// https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
+/// Optimized sum reduction - sums all the values across the warp.
+/// https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
 template <typename T>
 __device__ __forceinline__ T warpReduceSum(T val) {
 #pragma unroll
@@ -38,6 +43,7 @@ __device__ __forceinline__ T warpReduceSum(T val) {
   return val;
 }
 
+/// Optimized minimum reduction - finds the minimum across the values in the warp.
 template <typename T>
 __device__ __forceinline__ T warpReduceMin(T val) {
 #pragma unroll
@@ -47,7 +53,8 @@ __device__ __forceinline__ T warpReduceMin(T val) {
   return val;
 }
 
-// https://github.com/parallel-forall/code-samples/blob/master/posts/cuda-aware-mpi-example/src/Device.cu#L53
+/// This is how atomicMin() for floats would look (CUDA ships only the integer overloads).
+/// https://github.com/parallel-forall/code-samples/blob/master/posts/cuda-aware-mpi-example/src/Device.cu#L53
 __device__ __forceinline__ void atomicMin(
     float *const address, const float value) {
   if (*address <= value) {
diff --git a/src/wrappers.h b/src/wrappers.h
index 1c57684..cec03e1 100644
--- a/src/wrappers.h
+++ b/src/wrappers.h
@@ -9,6 +9,10 @@ template <typename T>
 using unique_devptr_parent = std::unique_ptr<T, std::function<void(T*)>>;
 
+/// RAII for CUDA arrays. Calls cudaFree() on the bound pointer, but only
+/// if it is not nullptr (funnily enough, CUDA segfaults otherwise).
+/// As can be seen, it inherits the rest of the methods from std::unique_ptr.
+/// @tparam T The type of the array element.
 template <typename T>
 class unique_devptr : public unique_devptr_parent<T> {
  public:
@@ -16,6 +20,9 @@ class unique_devptr : public unique_devptr_parent<T> {
       ptr, fake? [](T*){} : [](T *p){ cudaFree(p); }) {}
 };
 
+/// std::vector of unique_devptr-s. Used to pass device arrays into the .cu
+/// wrapping functions.
+/// @tparam T The type of the array element.
 template <typename T>
 using udevptrs = std::vector<unique_devptr<T>>;
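+
+// A hedged usage sketch (the names are illustrative); the cudaMalloc-ed
+// buffer is freed automatically when the wrapper goes out of scope:
+//
+//   float *raw = nullptr;
+//   cudaMalloc(&raw, n * sizeof(float));
+//   unique_devptr<float> owned(raw);  // cudaFree(raw) runs on destruction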