Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 95 additions & 25 deletions examples/sql-using-python-udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,50 +15,120 @@
# specific language governing permissions and limitations
# under the License.

"""
Example demonstrating how to use Python User Defined Functions (UDFs) with DataFusion.
This example shows how to:
1. Create and register a UDF
2. Create a table with test data
3. Execute a SQL query using the UDF
4. Handle different DataFusion API versions
"""

import pyarrow as pa
from datafusion import SessionContext, udf
from datafusion import SessionContext, udf, DataFrame

# Print version information for debugging
import datafusion
import pyarrow

print(f"DataFusion version: {datafusion.__version__}")
print(f"PyArrow version: {pyarrow.__version__}")


# Define a user-defined function (UDF)
# Define a user-defined function (UDF) that checks if a value is null
def is_null(array: pa.Array) -> pa.Array:
    """Return a boolean mask marking the null elements of *array*.

    Args:
        array (pa.Array): Input PyArrow array.

    Returns:
        pa.Array: Boolean array that is true wherever the input is null.
    """
    null_mask = array.is_null()
    return null_mask


# Create the UDF definition.
# Fix: the diff residue duplicated every argument (old form followed by new
# form), which is a SyntaxError (positional argument after keyword argument).
# Keep only the new, commented form.
is_null_arr = udf(
    is_null,        # The Python function to use
    [pa.int64()],   # Input type(s) - here we expect one int64 column
    pa.bool_(),     # Output type - returns boolean
    "stable",       # Volatility - "stable" means same input = same output
    name="is_null", # SQL name for the function (defaults to the Python name)
)

# Create a context
# Create a DataFusion session context
ctx = SessionContext()

# Create a datafusion DataFrame from a Python dictionary
ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]}, name="t")
# Dataframe:
# +---+---+
# | a | b |
# +---+---+
# | 1 | 4 |
# | 2 | |
# | 3 | 6 |
# +---+---+

# Register UDF for use in SQL
try:
# Method 1: Using DataFrame.from_pydict (for newer DataFusion versions)
print("\nTrying Method 1: DataFrame.from_pydict")
df = DataFrame.from_pydict(ctx, {
"a": [1, 2, 3],
"b": [4, None, 6]
})
df.create_or_replace_table("t")
except Exception as e:
print(f"Method 1 failed: {e}")

try:
# Method 2: Using arrow table directly
print("\nTrying Method 2: Register arrow table")
table = pa.table({
"a": [1, 2, 3],
"b": [4, None, 6]
})
ctx.register_table("t", table)
except Exception as e:
print(f"Method 2 failed: {e}")

# Method 3: Using explicit record batch creation
print("\nTrying Method 3: Explicit record batch creation")
# Define the schema for our data
schema = pa.schema([
('a', pa.int64()), # Column 'a' is int64
('b', pa.int64()) # Column 'b' is int64
])

# Create a record batch with our data
batch = pa.record_batch([
pa.array([1, 2, 3], type=pa.int64()), # Data for column 'a'
pa.array([4, None, 6], type=pa.int64()) # Data for column 'b'
], schema=schema)

# Register the record batch with DataFusion
# Note: The double list [[batch]] is required by the API
ctx.register_record_batches("t", [[batch]])

# Register our UDF with the context
ctx.register_udf(is_null_arr)

# Query the DataFrame using SQL
print("\nExecuting SQL query...")
# Execute a SQL query that uses our UDF
result_df = ctx.sql("select a, is_null(b) as b_is_null from t")

# Expected output:
# +---+-----------+
# | a | b_is_null |
# +---+-----------+
# | 1 | false     |
# | 2 | true      |
# | 3 | false     |
# +---+-----------+

# Convert result to dictionary and display.
# Fix: the diff residue kept the old bare assert that called to_pydict() a
# second time; materialize the dict once and assert on it below instead.
result_dict = result_df.to_pydict()
print("\nQuery Results:")
print("Result:", result_dict)

# Verify the results
assert result_dict["b_is_null"] == [False, True, False], "Unexpected results from UDF"
print("\nAssert passed - UDF working as expected!")

# Print a formatted version of the results
print("\nFormatted Results:")
print("+---+-----------+")
print("| a | b_is_null |")
print("+---+-----------+")
# Iterate the two columns in lockstep instead of indexing by range(len(...)).
for a_val, null_val in zip(result_dict["a"], result_dict["b_is_null"]):
    print(f"| {a_val} | {str(null_val).lower():9} |")
print("+---+-----------+")
18 changes: 16 additions & 2 deletions examples/substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,29 @@
# specific language governing permissions and limitations
# under the License.

import os
from datafusion import SessionContext
from datafusion import substrait as ss

# Get the directory of the current script
script_dir = os.path.dirname(os.path.abspath(__file__))

# Construct the path to the CSV file
# Using os.path.join for cross-platform compatibility
csv_file_path = os.path.join(script_dir, '..', 'testing', 'data', 'csv', 'aggregate_test_100.csv')

# Create a DataFusion context
ctx = SessionContext()

# Fix: the diff residue kept the old unguarded register_csv call with the
# hard-coded relative path in addition to this try/except; only the guarded,
# script-relative registration is kept.
try:
    # Register table with context
    ctx.register_csv("aggregate_test_data", csv_file_path)
except Exception as e:
    print(f"Error registering CSV file: {e}")
    print(f"Looking for file at: {csv_file_path}")
    raise

# Create Substrait plan from SQL query
substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx)
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>

Expand Down