From 7a0316b57c898096b6549ea2feeddc1de0daff77 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 21 Feb 2023 20:46:29 +0800 Subject: [PATCH] feat(udf): minimal Python UDF SDK (#7943) This PR designs a minimal SDK for Python UDFs. Now you can define a function in Python like this: ```python from risingwave.udf import udf, UdfServer @udf(input_types=['INT', 'INT'], result_type='INT') def gcd(x: int, y: int) -> int: while y != 0: (x, y) = (y, x % y) return x if __name__ == '__main__': server = UdfServer() server.add_function(gcd) server.serve() ``` This PR also fixes the problem when functions have no input arguments. Approved-By: xxchan Approved-By: BugenZhao --- .gitignore | 3 + Cargo.lock | 50 ++++---- Cargo.toml | 7 ++ ci/Dockerfile | 5 +- ci/build-ci-image.sh | 2 +- ci/docker-compose.yml | 14 ++- ci/scripts/run-e2e-test.sh | 6 + e2e_test/ddl/function.slt | 27 ----- e2e_test/udf/python.slt | 49 ++++++++ e2e_test/udf/test.py | 29 +++++ src/common/Cargo.toml | 4 +- src/expr/Cargo.toml | 4 +- src/expr/src/expr/expr_udf.rs | 19 +-- src/udf/Cargo.toml | 6 +- src/udf/README.md | 2 +- src/udf/arrow_flight.py | 55 --------- src/udf/examples/client.rs | 43 +++++-- src/udf/python/example.py | 27 +++++ src/udf/python/risingwave/__init__.py | 0 src/udf/python/risingwave/udf.py | 168 ++++++++++++++++++++++++++ src/udf/src/lib.rs | 14 ++- 21 files changed, 386 insertions(+), 148 deletions(-) delete mode 100644 e2e_test/ddl/function.slt create mode 100644 e2e_test/udf/python.slt create mode 100644 e2e_test/udf/test.py delete mode 100644 src/udf/arrow_flight.py create mode 100644 src/udf/python/example.py create mode 100644 src/udf/python/risingwave/__init__.py create mode 100644 src/udf/python/risingwave/udf.py diff --git a/.gitignore b/.gitignore index eb7560717e61..3ebf11628f38 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,9 @@ cmake-build-debug/ *.app build/ +# Python +*.pyc + # Golang go/bin/ diff --git a/Cargo.lock b/Cargo.lock index fd5a01346334..d35714277565 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -134,9 +134,8 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow-array" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e6e839764618a911cc460a58ebee5ad3d42bc12d9a5e96a29b7cc296303aa1" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" dependencies = [ "ahash 0.8.3", "arrow-buffer", @@ -150,9 +149,8 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a21d232b1bc1190a3fdd2f9c1e39b7cd41235e95a0d44dd4f522bc5f495748" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" dependencies = [ "half 2.2.1", "num", @@ -160,9 +158,8 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83dcdb1436cac574f1c1b30fda91c53c467534337bef4064bbd4ea2d6fbc6e04" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" dependencies = [ "arrow-array", "arrow-buffer", @@ -176,9 +173,8 @@ dependencies = [ [[package]] name = "arrow-data" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14e3e69c9fd98357eeeab4aa0f626ecf7ecf663e68e8fc04eac87c424a414477" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" dependencies = [ "arrow-buffer", "arrow-schema", @@ -188,9 +184,8 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd3ce08d31a1a24497bcf144029f8475539984aa50e41585e01b2057cf3dbb21" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" dependencies = [ "arrow-array", "arrow-buffer", @@ -211,9 +206,8 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64cac2706acbd796965b6eaf0da30204fe44aacf70273f8cb3c9b7d7f3d4c190" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" dependencies = [ "arrow-array", "arrow-buffer", @@ -225,15 +219,13 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73ca49d010b27e2d73f70c1d1f90c1b378550ed0f4ad379c4dea0c997d97d723" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" [[package]] name = "arrow-select" -version = "31.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976cbaeb1a85c09eea81f3f9c149c758630ff422ed0238624c5c3f4704b6a53c" +version = "33.0.0" +source = "git+https://github.com/apache/arrow-rs.git?rev=9a6c516#9a6c516f6e5c5411489a65af2e53dba041a26025" dependencies = [ "arrow-array", "arrow-buffer", @@ -2291,12 +2283,12 @@ checksum = "cda653ca797810c02f7ca4b804b40b8b95ae046eb989d356bce17919a8c25499" [[package]] name = "flatbuffers" -version = "22.9.29" +version = "23.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce016b9901aef3579617931fbb2df8fc9a9f7cb95a16eb8acc8148209bb9e70" +checksum = "77f5399c2c9c50ae9418e522842ad362f61ee48b346ac106807bd355a8a7c619" dependencies = [ "bitflags", - "thiserror", + "rustc_version", ] [[package]] @@ -4847,9 +4839,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.49" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] diff --git a/Cargo.toml b/Cargo.toml index a033e61dab5f..e7fc277f3fa0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -123,3 +123,10 @@ tokio-stream = { git = "https://github.com/madsim-rs/tokio.git", rev = "0c25710" tokio-retry = { git = "https://github.com/madsim-rs/rust-tokio-retry.git", rev = "95e2fd3" } tokio-postgres = { git = "https://github.com/madsim-rs/rust-postgres.git", rev = "87ca1dc" } postgres-types = { git = "https://github.com/madsim-rs/rust-postgres.git", rev = "87ca1dc" } + +# TODO: remove these patches when arrow releases v34 +# we need this commit for handling batch with no columns +# https://github.com/apache/arrow-rs/commit/ea48b9571f88bfbced60f9790ae2a7102502870e +arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "9a6c516" } +arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "9a6c516" } +arrow-flight = { git = "https://github.com/apache/arrow-rs.git", rev = "9a6c516" } diff --git a/ci/Dockerfile b/ci/Dockerfile index da9ba9dc5fae..cf19293e861c 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -5,7 +5,7 @@ ENV LANG en_US.utf8 ARG RUST_TOOLCHAIN RUN apt-get update -yy && \ - DEBIAN_FRONTEND=noninteractive apt-get -y install make build-essential cmake protobuf-compiler curl parallel \ + DEBIAN_FRONTEND=noninteractive apt-get -y install make build-essential cmake protobuf-compiler curl parallel python3 python3-pip \ openssl libssl-dev libsasl2-dev libcurl4-openssl-dev pkg-config bash openjdk-11-jdk wget unzip git tmux lld postgresql-client kafkacat netcat mysql-client \ maven -yy \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ @@ -23,6 +23,9 @@ WORKDIR /risingwave ENV PATH /root/.cargo/bin/:$PATH +# install python dependencies +RUN pip3 install pyarrow + # add required rustup components RUN rustup component add rustfmt llvm-tools-preview clippy diff --git a/ci/build-ci-image.sh b/ci/build-ci-image.sh index 875c20a8d764..ffb8b0fbe8d0 100755 --- a/ci/build-ci-image.sh +++ b/ci/build-ci-image.sh @@ -14,7 +14,7 @@ export RUST_TOOLCHAIN=$(cat ../rust-toolchain) # !!! CHANGE THIS WHEN YOU WANT TO BUMP CI IMAGE !!! # # AND ALSO docker-compose.yml # ###################################################### -export BUILD_ENV_VERSION=v20230220 +export BUILD_ENV_VERSION=v20230221_01 export BUILD_TAG="public.ecr.aws/x5u3w5h6/rw-build-env:${BUILD_ENV_VERSION}" diff --git a/ci/docker-compose.yml b/ci/docker-compose.yml index d33b7163f372..6b4493a0fd85 100644 --- a/ci/docker-compose.yml +++ b/ci/docker-compose.yml @@ -24,13 +24,17 @@ services: - MYSQL_USER=mysqluser - MYSQL_PASSWORD=mysqlpw healthcheck: - test: [ "CMD-SHELL", "mysqladmin ping -h 127.0.0.1 -u root -p123456" ] + test: + [ + "CMD-SHELL", + "mysqladmin ping -h 127.0.0.1 -u root -p123456" + ] interval: 5s timeout: 5s retries: 5 source-test-env: - image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230220 + image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230221_01 depends_on: - mysql - db @@ -38,7 +42,7 @@ services: - ..:/risingwave sink-test-env: - image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230220 + image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230221_01 depends_on: - mysql - db @@ -46,12 +50,12 @@ services: - ..:/risingwave rw-build-env: - image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230220 + image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230221_01 volumes: - ..:/risingwave regress-test-env: - image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230220 + image: public.ecr.aws/x5u3w5h6/rw-build-env:v20230221_01 depends_on: db: condition: service_healthy diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index da50b2656604..93dc93447de7 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -55,6 +55,12 @@ sqllogictest -p 4566 -d dev './e2e_test/batch/**/*.slt' --junit "batch-${profile sqllogictest -p 4566 -d dev './e2e_test/database/prepare.slt' sqllogictest -p 4566 -d test './e2e_test/database/test.slt' +echo "--- e2e, ci-3cn-1fe, udf" +python3 e2e_test/udf/test.py & +sleep 2 +sqllogictest -p 4566 -d dev './e2e_test/udf/python.slt' +pkill python3 + echo "--- Kill cluster" cargo make ci-kill diff --git a/e2e_test/ddl/function.slt b/e2e_test/ddl/function.slt deleted file mode 100644 index 5f7dea9487f7..000000000000 --- a/e2e_test/ddl/function.slt +++ /dev/null @@ -1,27 +0,0 @@ -# TODO: check the service on creation - -# Create a function. -statement ok -create function func(int, int) returns int as 'http://localhost:8815' language arrow_flight; - -# Create a function with the same name but different arguments. -statement ok -create function func(int) returns int as 'http://localhost:8815' language arrow_flight; - -# Create a function with the same name and arguments. -statement error exists -create function func(int) returns int as 'http://localhost:8815' language arrow_flight; - -# TODO: drop function without arguments - -# # Drop a function but ambiguous. -# statement error is not unique -# drop function func; - -# Drop a function -statement ok -drop function func(int); - -# Drop a function -statement ok -drop function func(int, int); diff --git a/e2e_test/udf/python.slt b/e2e_test/udf/python.slt new file mode 100644 index 000000000000..405ac555cb87 --- /dev/null +++ b/e2e_test/udf/python.slt @@ -0,0 +1,49 @@ +# Before running this test: +# python3 e2e_test/udf/test.py + +# TODO: check the service on creation +# Currently whether the function exists in backend and whether the signature matches is checked on execution. Create function will always succeed. + +# Create a function. +statement ok +create function int_42() returns int as 'http://localhost:8815' language arrow_flight; + +statement ok +create function gcd(int, int) returns int as 'http://localhost:8815' language arrow_flight; + +# Create a function with the same name but different arguments. +statement ok +create function gcd(int, int, int) returns int as 'http://localhost:8815' language arrow_flight; + +# Create a function with the same name and arguments. +statement error exists +create function gcd(int, int) returns int as 'http://localhost:8815' language arrow_flight; + +query I +select int_42(); +---- +42 + +query I +select gcd(25, 15); +---- +5 + +query I +select gcd(25, 15, 3); +---- +1 + +# TODO: drop function without arguments + +# # Drop a function but ambiguous. +# statement error is not unique +# drop function gcd; + +# Drop a function +statement ok +drop function gcd(int, int); + +# Drop a function +statement ok +drop function gcd(int, int, int); diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py new file mode 100644 index 000000000000..ddb907629c12 --- /dev/null +++ b/e2e_test/udf/test.py @@ -0,0 +1,29 @@ +import sys +sys.path.append('src/udf/python') # noqa + +from risingwave.udf import udf, UdfServer + + +@udf(input_types=[], result_type='INT') +def int_42() -> int: + return 42 + + +@udf(input_types=['INT', 'INT'], result_type='INT') +def gcd(x: int, y: int) -> int: + while y != 0: + (x, y) = (y, x % y) + return x + + +@udf(name='gcd', input_types=['INT', 'INT', 'INT'], result_type='INT') +def gcd3(x: int, y: int, z: int) -> int: + return gcd(gcd(x, y), z) + + +if __name__ == '__main__': + server = UdfServer() + server.add_function(int_42) + server.add_function(gcd) + server.add_function(gcd3) + server.serve() diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index fe3478820949..c6a3fe1393e4 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -15,8 +15,8 @@ normal = ["workspace-hack"] [dependencies] anyhow = "1" -arrow-array = "31" -arrow-schema = "31" +arrow-array = "33" +arrow-schema = "33" async-trait = "0.1" auto_enums = "0.7" bitflags = "1.3.2" diff --git a/src/expr/Cargo.toml b/src/expr/Cargo.toml index cc09f522ca9d..bc2d71a93b1d 100644 --- a/src/expr/Cargo.toml +++ b/src/expr/Cargo.toml @@ -17,8 +17,8 @@ normal = ["workspace-hack"] [dependencies] aho-corasick = "0.7" anyhow = "1" -arrow-array = "31" -arrow-schema = "31" +arrow-array = "33" +arrow-schema = "33" chrono = { version = "0.4", default-features = false, features = ["clock", "std"] } chrono-tz = { version = "0.7", features = ["case-insensitive"] } dyn-clone = "1" diff --git a/src/expr/src/expr/expr_udf.rs b/src/expr/src/expr/expr_udf.rs index b64f01add35c..2726a0b75376 100644 --- a/src/expr/src/expr/expr_udf.rs +++ b/src/expr/src/expr/expr_udf.rs @@ -15,7 +15,7 @@ use std::convert::TryFrom; use std::sync::Arc; -use arrow_schema::{Field, Schema}; +use arrow_schema::{Field, Schema, SchemaRef}; use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk}; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; @@ -33,6 +33,7 @@ pub struct UdfExpression { // name: String, arg_types: Vec, return_type: DataType, + arg_schema: SchemaRef, client: ArrowFlightUdfClient, function_id: FunctionId, } @@ -51,10 +52,13 @@ impl Expression for UdfExpression { let columns: Vec<_> = self .children .iter() - .map(|c| c.eval_checked(input).map(|a| ("", a.as_ref().into()))) + .map(|c| c.eval_checked(input).map(|a| a.as_ref().into())) .try_collect()?; + let opts = + arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality())); let input = - arrow_array::RecordBatch::try_from_iter(columns).expect("failed to build record batch"); + arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts) + .expect("failed to build record batch"); let output = tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(self.client.call(&self.function_id, input)) })?; @@ -87,18 +91,18 @@ impl<'a> TryFrom<&'a ExprNode> for UdfExpression { bail!("expect UDF"); }; // connect to UDF service and check the function - let (client, function_id) = tokio::task::block_in_place(|| { + let (client, function_id, arg_schema) = tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { let client = ArrowFlightUdfClient::connect(&udf.path).await?; - let args = Schema::new( + let args = Arc::new(Schema::new( udf.arg_types .iter() .map(|t| Field::new("", DataType::from(t).into(), true)) .collect(), - ); + )); let returns = Schema::new(vec![Field::new("", (&return_type).into(), true)]); let id = client.check(&udf.name, &args, &returns).await?; - Ok((client, id)) as risingwave_udf::Result<_> + Ok((client, id, args)) as risingwave_udf::Result<_> }) })?; Ok(Self { @@ -106,6 +110,7 @@ impl<'a> TryFrom<&'a ExprNode> for UdfExpression { // name: udf.name.clone(), arg_types: udf.arg_types.iter().map(|t| t.into()).collect(), return_type, + arg_schema, client, function_id, }) diff --git a/src/udf/Cargo.toml b/src/udf/Cargo.toml index 951a93ced41d..36820820a618 100644 --- a/src/udf/Cargo.toml +++ b/src/udf/Cargo.toml @@ -11,9 +11,9 @@ ignored = ["workspace-hack"] normal = ["workspace-hack"] [dependencies] -arrow-array = "31" -arrow-flight = "31" -arrow-schema = "31" +arrow-array = "33" +arrow-flight = "33" +arrow-schema = "33" futures-util = "0.3.25" thiserror = "1" tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] } diff --git a/src/udf/README.md b/src/udf/README.md index ee41f396548b..6364e8513a05 100644 --- a/src/udf/README.md +++ b/src/udf/README.md @@ -7,7 +7,7 @@ ```sh pip3 install pyarrow # run server -python3 arrow_flight.py +python3 python/example.py # run client (test client for the arrow flight UDF client-server protocol) cargo run --example client ``` diff --git a/src/udf/arrow_flight.py b/src/udf/arrow_flight.py deleted file mode 100644 index f464a12332ad..000000000000 --- a/src/udf/arrow_flight.py +++ /dev/null @@ -1,55 +0,0 @@ -import pathlib - -import pyarrow as pa -import pyarrow.flight -import pyarrow.parquet - - -class FlightServer(pa.flight.FlightServerBase): - """ - Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight - """ - - def __init__(self, location="grpc://0.0.0.0:8815", **kwargs): - super(FlightServer, self).__init__(location, **kwargs) - self._location = location - self._functions = {} - - def get_flight_info(self, context, descriptor): - """Return a FlightInfo.""" - return pa.flight.FlightInfo(schema=pa.schema([ - ('c', pa.int32()), - ]), descriptor=descriptor, endpoints=[], total_records=0, total_bytes=0) - - def add_function(self, name: str, func): - """Add a function to the server.""" - self._functions[name] = func - - def do_exchange(self, context, descriptor, reader, writer): - """Run a simple echo server.""" - func = self._functions[descriptor.path[0].decode('utf-8')] - schema = pa.schema([ - ('c', pa.int32()), - ]) - writer.begin(schema) - for chunk in reader: - print(pa.Table.from_batches([chunk.data])) - result = self.call_func(func, chunk.data) - writer.write_table(result) - - def call_func(self, func, batch: pa.RecordBatch) -> pa.Table: - data = pa.array([func(batch[0][i].as_py(), batch[1][i].as_py()) - for i in range(len(batch[0]))]) - return pa.Table.from_arrays([data], names=['c']) - - -def gcd(x: int, y: int) -> int: - while y != 0: - (x, y) = (y, x % y) - return x - - -if __name__ == '__main__': - server = FlightServer() - server.add_function("gcd", gcd) - server.serve() diff --git a/src/udf/examples/client.rs b/src/udf/examples/client.rs index 034d3816eab4..415e00b3151e 100644 --- a/src/udf/examples/client.rs +++ b/src/udf/examples/client.rs @@ -24,28 +24,51 @@ async fn main() { let client = ArrowFlightUdfClient::connect(addr).await.unwrap(); // build `RecordBatch` to send (equivalent to our `DataChunk`) - let array1 = Int32Array::from_iter(vec![1, 6, 10]); - let array2 = Int32Array::from_iter(vec![3, 4, 15]); - let input_schema = Schema::new(vec![ + let array1 = Arc::new(Int32Array::from_iter(vec![1, 6, 10])); + let array2 = Arc::new(Int32Array::from_iter(vec![3, 4, 15])); + let array3 = Arc::new(Int32Array::from_iter(vec![6, 8, 3])); + let input2_schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), ]); - let output_schema = Schema::new(vec![Field::new("c", DataType::Int32, true)]); + let input3_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let output_schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]); // check function - let id = client - .check("gcd", &input_schema, &output_schema) + let gcd2 = client + .check("gcd", &input2_schema, &output_schema) .await .unwrap(); + let gcd3 = client + .check("gcd", &input3_schema, &output_schema) + .await + .unwrap(); + + let input2 = RecordBatch::try_new( + Arc::new(input2_schema), + vec![array1.clone(), array2.clone()], + ) + .unwrap(); + + let output = client + .call(&gcd2, input2) + .await + .expect("failed to call function"); + + println!("{:?}", output); - let input = RecordBatch::try_new( - Arc::new(input_schema), - vec![Arc::new(array1), Arc::new(array2)], + let input3 = RecordBatch::try_new( + Arc::new(input3_schema), + vec![array1.clone(), array2.clone(), array3.clone()], ) .unwrap(); let output = client - .call(&id, input) + .call(&gcd3, input3) .await .expect("failed to call function"); diff --git a/src/udf/python/example.py b/src/udf/python/example.py new file mode 100644 index 000000000000..1b615728fe01 --- /dev/null +++ b/src/udf/python/example.py @@ -0,0 +1,27 @@ +from risingwave.udf import udf, UdfServer +import random + + +@udf(input_types=[], result_type='INT') +def random_int() -> int: + return random.randint(0, 100) + + +@udf(input_types=['INT', 'INT'], result_type='INT') +def gcd(x: int, y: int) -> int: + while y != 0: + (x, y) = (y, x % y) + return x + + +@udf(name='gcd', input_types=['INT', 'INT', 'INT'], result_type='INT') +def gcd3(x: int, y: int, z: int) -> int: + return gcd(gcd(x, y), z) + + +if __name__ == '__main__': + server = UdfServer() + server.add_function(random_int) + server.add_function(gcd) + server.add_function(gcd3) + server.serve() diff --git a/src/udf/python/risingwave/__init__.py b/src/udf/python/risingwave/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py new file mode 100644 index 000000000000..bdce30a56387 --- /dev/null +++ b/src/udf/python/risingwave/udf.py @@ -0,0 +1,168 @@ +from typing import * +import pyarrow as pa +import pyarrow.flight +import pyarrow.parquet + + +class UserDefinedFunction: + """ + Base interface for user-defined function. + """ + _name: str + _input_types: List[pa.DataType] + _result_type: pa.DataType + + def full_name(self) -> str: + """ + A unique name for the function. Composed by function name and input types. + Example: "gcd/int32,int32" + """ + return self._name + '/' + ','.join([str(t) for t in self._input_types]) + + def result_schema(self) -> pa.Schema: + """ + Returns the schema of the result table. + """ + return pa.schema([('', self._result_type)]) + + def eval_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch: + """ + Apply the function on a batch of inputs. + """ + pass + + +class ScalarFunction(UserDefinedFunction): + """ + Base interface for user-defined scalar function. A user-defined scalar functions maps zero, one, + or multiple scalar values to a new scalar value. + """ + + def eval(self, *args): + """ + Method which defines the logic of the scalar function. + """ + pass + + def eval_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch: + result = pa.array([self.eval(*[col[i].as_py() for col in batch]) + for i in range(batch.num_rows)], + type=self._result_type) + return pa.RecordBatch.from_arrays([result], schema=self.result_schema()) + + +class UserDefinedFunctionWrapper(ScalarFunction): + """ + Base Wrapper for Python user-defined function. + """ + _func: Callable + + def __init__(self, func, input_types, result_type, name=None): + self._func = func + self._input_types = [_to_data_type(t) for t in input_types] + self._result_type = _to_data_type(result_type) + self._name = name or ( + func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) + + def __call__(self, *args): + return self._func(*args) + + def eval(self, *args): + return self._func(*args) + + +def _create_udf(f, input_types, result_type, name): + return UserDefinedFunctionWrapper( + f, input_types, result_type, name) + + +def udf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], + result_type: Union[str, pa.DataType], + name: Optional[str] = None,) -> Union[Callable, UserDefinedFunction]: + """ + Annotation for creating a user-defined function. + """ + + return lambda f: _create_udf(f, input_types, result_type, name) + + +class UdfServer(pa.flight.FlightServerBase): + """ + UDF server based on Apache Arrow Flight protocol. + Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight + """ + _functions: Dict[str, UserDefinedFunction] + + def __init__(self, location="grpc://0.0.0.0:8815", **kwargs): + super(UdfServer, self).__init__(location, **kwargs) + self._functions = {} + + def get_flight_info(self, context, descriptor): + """Return the result schema of a function.""" + udf = self._functions[descriptor.path[0].decode('utf-8')] + return pa.flight.FlightInfo(schema=udf.result_schema(), descriptor=descriptor, endpoints=[], total_records=0, total_bytes=0) + + def add_function(self, udf: UserDefinedFunction): + """Add a function to the server.""" + name = udf.full_name() + if name in self._functions: + raise ValueError('Function already exists: ' + name) + print('added function:', name) + self._functions[name] = udf + + def do_exchange(self, context, descriptor, reader, writer): + """Call a function from the client.""" + udf = self._functions[descriptor.path[0].decode('utf-8')] + writer.begin(udf.result_schema()) + for chunk in reader: + # print(pa.Table.from_batches([chunk.data])) + result = udf.eval_batch(chunk.data) + writer.write_batch(result) + + def serve(self): + """Start the server.""" + super(UdfServer, self).serve() + + +def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType: + """ + Convert a string or pyarrow.DataType to pyarrow.DataType. + """ + if isinstance(t, str): + return _string_to_data_type(t) + else: + return t + + +def _string_to_data_type(type_str: str): + match type_str: + case 'BOOLEAN': + return pa.bool_() + case 'TINYINT': + return pa.int8() + case 'SMALLINT': + return pa.int16() + case 'INT' | 'INTEGER': + return pa.int32() + case 'BIGINT': + return pa.int64() + case 'FLOAT' | 'REAL': + return pa.float32() + case 'DOUBLE': + return pa.float64() + case 'DECIMAL': + return pa.decimal128(38) + case 'DATE': + return pa.date32() + case 'DATETIME': + return pa.timestamp('ms') + case 'TIME': + return pa.time32('ms') + case 'TIMESTAMP': + return pa.timestamp('us') + case 'CHAR' | 'VARCHAR': + return pa.string() + case 'BINARY' | 'VARBINARY': + return pa.binary() + case _: + raise ValueError(f'Unsupported type: {type_str}') diff --git a/src/udf/src/lib.rs b/src/udf/src/lib.rs index 6d8a2bb8f325..4a4d48ab7ef6 100644 --- a/src/udf/src/lib.rs +++ b/src/udf/src/lib.rs @@ -38,11 +38,15 @@ impl ArrowFlightUdfClient { /// Check if the function is available and return the function ID. pub async fn check(&self, name: &str, args: &Schema, returns: &Schema) -> Result { - let mut path = vec![name.to_string()]; - for arg in &args.fields { - path.push(format!("{}", arg.data_type())); + // path = name/[args,]* + let mut path = name.to_string() + "/"; + for (i, arg) in args.fields.iter().enumerate() { + if i != 0 { + path += ","; + } + path += &arg.data_type().to_string().to_lowercase(); } - let descriptor = FlightDescriptor::new_path(path.clone()); + let descriptor = FlightDescriptor::new_path(vec![path.clone()]); let response = self.client.clone().get_flight_info(descriptor).await?; @@ -58,7 +62,7 @@ impl ArrowFlightUdfClient { actual: format!("{:?}", actual_types), }); } - Ok(FunctionId(path)) + Ok(FunctionId(vec![path])) } /// Call a function.