From eeac48961b901ec28adc9a78a3fc63d0e4f5d92e Mon Sep 17 00:00:00 2001 From: Wenlei Xie Date: Wed, 22 Jun 2022 23:17:40 -0700 Subject: [PATCH] Add clamp_list UDF (#397) Summary: Pull Request resolved: https://github.com/pytorch/torcharrow/pull/397 Basically apply clamp to each element in the list. Used for sparse feature preproc in recommendation domain. This UDF will be deprecated once TorchArrow supports lambda function. Reviewed By: bearzx Differential Revision: D37370459 fbshipit-source-id: f64d59129dd980f886241ce1ac8cfaee251c0a01 --- csrc/velox/functions/CMakeLists.txt | 5 +- csrc/velox/functions/functions.h | 29 +++++++++++ csrc/velox/functions/rec/clamp_list.h | 37 ++++++++++++++ torcharrow/test/transformation/test_clamp.py | 53 ++++++++++++++++++++ 4 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 csrc/velox/functions/rec/clamp_list.h create mode 100644 torcharrow/test/transformation/test_clamp.py diff --git a/csrc/velox/functions/CMakeLists.txt b/csrc/velox/functions/CMakeLists.txt index 0c4012456..c1931a13f 100644 --- a/csrc/velox/functions/CMakeLists.txt +++ b/csrc/velox/functions/CMakeLists.txt @@ -13,14 +13,15 @@ set( rec/sigrid_hash.h rec/firstX.h rec/compute_score.h + rec/clamp_list.h register_udf.cpp - ) +) set( TORCHARROW_UDF_LINK_LIBRARIES velox_functions_string velox_functions_prestosql - ) +) set(TORCHARROW_UDF_COMPILE_DEFINITIONS) if (USE_TORCH) diff --git a/csrc/velox/functions/functions.h b/csrc/velox/functions/functions.h index ba7b99627..f0ecddecb 100644 --- a/csrc/velox/functions/functions.h +++ b/csrc/velox/functions/functions.h @@ -11,6 +11,7 @@ #include #include "numeric_functions.h" #include "rec/bucketize.h" // @manual +#include "rec/clamp_list.h" // @manual #include "rec/compute_score.h" // @manual #include "rec/firstX.h" // @manual #include "rec/sigrid_hash.h" // @manual @@ -368,6 +369,34 @@ inline void registerTorchArrowFunctions() { velox::Array, velox::Array>({"get_score_max"}); + velox::registerFunction< + ClampListFunction, + velox::Array, + velox::Array, + int32_t, + int32_t>({"clamp_list"}); + + velox::registerFunction< + ClampListFunction, + velox::Array, + velox::Array, + int64_t, + int64_t>({"clamp_list"}); + + velox::registerFunction< + ClampListFunction, + velox::Array, + velox::Array, + float, + float>({"clamp_list"}); + + velox::registerFunction< + ClampListFunction, + velox::Array, + velox::Array, + double, + double>({"clamp_list"}); + // TODO: consider to refactor registration code with helper functions // to save some lines, like https://fburl.com/code/dk6zi7t3 diff --git a/csrc/velox/functions/rec/clamp_list.h b/csrc/velox/functions/rec/clamp_list.h new file mode 100644 index 000000000..56b41bc79 --- /dev/null +++ b/csrc/velox/functions/rec/clamp_list.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include "velox/functions/Udf.h" +#include "velox/type/Type.h" + +namespace facebook::torcharrow::functions { + +// TODO: remove this function once lambda expression is supported in +// TorchArrow +template +struct ClampListFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + template + FOLLY_ALWAYS_INLINE void callNullFree( + TOutput& result, + const TInput& values, + const TElement& lo, + const TElement& hi) { + VELOX_USER_CHECK_LE(lo, hi, "Lo > hi in clamp."); + result.reserve(values.size()); + for (const auto& val : values) { + result.push_back(std::clamp(val, lo, hi)); + } + } +}; + +} // namespace facebook::torcharrow::functions diff --git a/torcharrow/test/transformation/test_clamp.py b/torcharrow/test/transformation/test_clamp.py new file mode 100644 index 000000000..2a534b9be --- /dev/null +++ b/torcharrow/test/transformation/test_clamp.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torcharrow as ta +import torcharrow.dtypes as dt +from torcharrow import functional + + +class _TestClampBase(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.base_df_list = ta.dataframe( + { + "int64": [[0, 1, 2, 3], [-100, 100, 10], [0, -1, -2, -3]], + }, + dtype=dt.Struct( + fields=[ + dt.Field("int64", dt.List(dt.int64)), + ] + ), + ) + + cls.setUpTestCaseData() + + @classmethod + def setUpTestCaseData(cls): + # Override in subclass + # Python doesn't have native "abstract base test" support. + # So use unittest.SkipTest to skip in base class: https://stackoverflow.com/a/59561905. + raise unittest.SkipTest("abstract base test") + + def test_clamp_list(self): + df = type(self).df_list + + self.assertEqual( + list(functional.clamp_list(df["int64"], 0, 20)), + [[0, 1, 2, 3], [0, 20, 10], [0, 0, 0, 0]], + ) + + +class TestClampCpu(_TestClampBase): + @classmethod + def setUpTestCaseData(cls): + cls.df_list = cls.base_df_list.copy() + + +if __name__ == "__main__": + unittest.main()