Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions csrc/velox/functions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ set(
rec/sigrid_hash.h
rec/firstX.h
rec/compute_score.h
rec/clamp_list.h
register_udf.cpp
)
)

set(
TORCHARROW_UDF_LINK_LIBRARIES
velox_functions_string
velox_functions_prestosql
)
)
set(TORCHARROW_UDF_COMPILE_DEFINITIONS)

if (USE_TORCH)
Expand Down
29 changes: 29 additions & 0 deletions csrc/velox/functions/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <velox/functions/Registerer.h>
#include "numeric_functions.h"
#include "rec/bucketize.h" // @manual
#include "rec/clamp_list.h" // @manual
#include "rec/compute_score.h" // @manual
#include "rec/firstX.h" // @manual
#include "rec/sigrid_hash.h" // @manual
Expand Down Expand Up @@ -368,6 +369,34 @@ inline void registerTorchArrowFunctions() {
velox::Array<int64_t>,
velox::Array<float>>({"get_score_max"});

velox::registerFunction<
ClampListFunction,
velox::Array<int32_t>,
velox::Array<int32_t>,
int32_t,
int32_t>({"clamp_list"});

velox::registerFunction<
ClampListFunction,
velox::Array<int64_t>,
velox::Array<int64_t>,
int64_t,
int64_t>({"clamp_list"});

velox::registerFunction<
ClampListFunction,
velox::Array<float>,
velox::Array<float>,
float,
float>({"clamp_list"});

velox::registerFunction<
ClampListFunction,
velox::Array<double>,
velox::Array<double>,
double,
double>({"clamp_list"});

// TODO: consider to refactor registration code with helper functions
// to save some lines, like https://fburl.com/code/dk6zi7t3

Expand Down
37 changes: 37 additions & 0 deletions csrc/velox/functions/rec/clamp_list.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cmath>
#include "velox/functions/Udf.h"
#include "velox/type/Type.h"

namespace facebook::torcharrow::functions {

// TODO: remove this function once lambda expression is supported in
// TorchArrow
template <typename T>
struct ClampListFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

template <typename TOutput, typename TInput, typename TElement>
FOLLY_ALWAYS_INLINE void callNullFree(
TOutput& result,
const TInput& values,
const TElement& lo,
const TElement& hi) {
VELOX_USER_CHECK_LE(lo, hi, "Lo > hi in clamp.");
result.reserve(values.size());
for (const auto& val : values) {
result.push_back(std::clamp(val, lo, hi));
}
}
};

} // namespace facebook::torcharrow::functions
53 changes: 53 additions & 0 deletions torcharrow/test/transformation/test_clamp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torcharrow as ta
import torcharrow.dtypes as dt
from torcharrow import functional


class _TestClampBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.base_df_list = ta.dataframe(
{
"int64": [[0, 1, 2, 3], [-100, 100, 10], [0, -1, -2, -3]],
},
dtype=dt.Struct(
fields=[
dt.Field("int64", dt.List(dt.int64)),
]
),
)

cls.setUpTestCaseData()

@classmethod
def setUpTestCaseData(cls):
# Override in subclass
# Python doesn't have native "abstract base test" support.
# So use unittest.SkipTest to skip in base class: https://stackoverflow.com/a/59561905.
raise unittest.SkipTest("abstract base test")

def test_clamp_list(self):
df = type(self).df_list

self.assertEqual(
list(functional.clamp_list(df["int64"], 0, 20)),
[[0, 1, 2, 3], [0, 20, 10], [0, 0, 0, 0]],
)


class TestClampCpu(_TestClampBase):
@classmethod
def setUpTestCaseData(cls):
cls.df_list = cls.base_df_list.copy()


if __name__ == "__main__":
unittest.main()