forked from apache/arrow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_udf.rs
138 lines (116 loc) · 5.03 KB
/
simple_udf.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow::{
array::{Array, ArrayRef, Float32Array, Float64Array, Float64Builder},
datatypes::DataType,
record_batch::RecordBatch,
util::pretty,
};
use datafusion::error::Result;
use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*};
use std::sync::Arc;
// create local execution context with an in-memory table
fn create_context() -> Result<ExecutionContext> {
use arrow::datatypes::{Field, Schema};
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float64, false),
]));
// define data.
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])),
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
],
)?;
// declare a new context. In spark API, this corresponds to a new spark SQLsession
let mut ctx = ExecutionContext::new();
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::new(schema, vec![vec![batch]])?;
ctx.register_table("t", Box::new(provider));
Ok(ctx)
}
/// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b
fn main() -> Result<()> {
    let mut ctx = create_context()?;
    // First, declare the actual implementation of the calculation.
    // In DataFusion, all `args` and the output are dynamically-typed arrays,
    // which means that we need to:
    // 1. downcast the values to the type we want
    // 2. perform the computation for every element in the array (using a loop or SIMD)
    // 3. construct the resulting array
    let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
        // Argument count is guaranteed by DataFusion based on the function's signature.
        assert_eq!(args.len(), 2);
        // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
        let base = &args[0]
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("cast failed");
        let exponent = &args[1]
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("cast failed");
        // Equal lengths are guaranteed by DataFusion. We place this just to make it obvious.
        assert_eq!(exponent.len(), base.len());
        // 2. Arrow's builder is used to construct the result, preallocated to the input length.
        let mut builder = Float64Builder::new(base.len());
        for index in 0..base.len() {
            // In arrow, any value can be null.
            // Here we decide to make our UDF return null when either base or exponent is null.
            if base.is_null(index) || exponent.is_null(index) {
                builder.append_null()?;
            } else {
                // 3. computation. Since we do not have any SIMD `pow` operation at our hands,
                // we loop over each entry. Array values are obtained via `.value(index)`.
                let value = base.value(index).powf(exponent.value(index));
                builder.append_value(value)?;
            }
        }
        Ok(Arc::new(builder.finish()))
    });
    // Next:
    // * give it a name so that it shows nicely when the plan is printed
    // * declare what input it expects
    // * declare its return type
    let pow = create_udf(
        "pow",
        // expects two f64
        vec![DataType::Float64, DataType::Float64],
        // returns f64
        Arc::new(DataType::Float64),
        pow,
    );
    // finally, register the UDF with the context
    ctx.register_udf(pow);
    // at this point, we can use it. Note that the code below can be in a
    // scope on which we do not have access to `pow`.
    // get a DataFrame from the context
    let df = ctx.table("t")?;
    // get the udf registry, which resolves registered functions by name.
    let f = df.registry();
    // equivalent to `'SELECT pow(a, b) FROM t'`
    let df = df.select(vec![f.udf("pow", vec![col("a"), col("b")])?])?;
    // note that "a" is f32, not f64. DataFusion coerces the types to match the UDF's signature.
    // execute the query
    let results = df.collect()?;
    // print the results
    pretty::print_batches(&results)?;
    Ok(())
}