## User-defined functions (`map_elements` and `map_batches`)
By the end of the lecture you will be able to:
- apply user-defined functions elementwise
- apply numpy ufuncs to an expression
- apply user-defined functions batchwise


In [None]:
import polars as pl
import numpy as np

In [None]:
csv_file = "../data/titanic.csv"
df = pl.read_csv(csv_file)

### Element-wise functions
We start by creating a simple function that takes a value and returns its square.

In [None]:
def square(x):
    return x ** 2

We apply this function to the full `DataFrame` on the `Age` column to create a new column called `age_squared`

In [None]:
(
    df.with_columns(
        age_squared = pl.col("Age").map_elements(square)
    )
    .select("Age","age_squared")
    .head()
)

We see that the operation worked but Polars emitted a warning.

#### Why does Polars output a warning?
Polars outputs a warning that using `map_elements` is probably a lot slower than an equivalent expression. This is because when we use `map_elements` Polars goes through the rows one-by-one applying the function in the (slower) Python layer.

If we instead use the equivalent Polars expression, Polars:
- runs the function in the (faster) Rust layer and
- runs the function once for all the rows together

**Always try to use Polars expressions instead of `map_elements`**

In some cases Polars tries to identify an equivalent expression and suggests it in the warning as a replacement for your function.

#### Lambda function
In the example above we defined the function in Python using a standard `def` function definition. We can also pass an inline lambda function to `map_elements`

In [None]:
(
    df.with_columns(
        age_squared = pl.col("Age").map_elements(lambda x: x**2)
    )
    .select("Age","age_squared")
    .head()
)

#### Caching repeated elements
If:
- the column you are applying the function to has many repeated elements and
- the function is expensive to compute

then you can use the `lru_cache` decorator from Python's built-in `functools` module. `lru_cache` stores previous inputs and outputs of the function so that previously calculated values are looked up instead of being re-calculated.

In [None]:
from functools import lru_cache

@lru_cache
def square(x):
    return x**2

LRU stands for "Least Recently Used," indicating that this cache discards the least recently used items first when the cache reaches its capacity limit
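To see the cache at work outside Polars, here is a minimal pure-Python sketch. The function name `slow_square` and the input values are hypothetical; `cache_info()` is the standard way to inspect how often a cached function was actually executed. Note that `lru_cache` requires the inputs to be hashable, which holds for floats and strings.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def slow_square(x):
    # Stand-in for an expensive computation
    return x**2

# Five calls, but only two distinct inputs
values = [3.0, 3.0, 4.0, 3.0, 4.0]
results = [slow_square(v) for v in values]

# misses = times the function body ran, hits = times the cache answered
info = slow_square.cache_info()
print(results, info)
```

Here the function body runs only for the two distinct inputs; the remaining three calls are served from the cache.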

#### Running in parallel
By default the function is run element-by-element on a single thread (i.e. not in parallel). We can instruct Polars to run the function multi-threaded (i.e. in parallel) with the `strategy` argument

In [None]:
(
    df
    .with_columns(
        age_squared = pl.col("Age").map_elements(lambda x: x**2,strategy="threading")
    )
    .select("Age","age_squared")
    .head()
)

Running with multiple threads in parallel is not guaranteed to run faster. Indeed, the extra overhead may even make it slower. The threading approach is more likely to help when the amount of work in each call of the function is relatively large.

#### Lazy mode
We can use `map_elements` in lazy mode. We can even use `map_elements` in streaming mode for larger-than-memory datasets: when we print the query plan with `.explain(streaming=True)` the entire query is inside the `--- STREAMING` block.

In [None]:
print(
    df
    .lazy()
    .with_columns(
        age_squared = pl.col("Age").map_elements(lambda x: x**2)
    )
    .select("Age","age_squared")
    .explain(streaming=True)
)

We can use `map_elements` in streaming mode because the function is applied to each element individually and so the function still works fine when the data is processed in batches in streaming.

#### Multiple columns
We can apply the same function to multiple columns in the same way that we would do with any other expression

In [None]:
(
    df
    .with_columns(
        pl.col("Age","Fare").map_elements(lambda x: x**2).name.suffix("_squared")
    )
    .select("Age","Age_squared","Fare","Fare_squared")
    .head()
)

If we have a more complicated function that requires interaction between the elements in different columns then we combine the columns needed into a `pl.Struct` and apply the function to that.

In the next section of the course we learn more about struct columns. For now all we need to know is that inside our user-defined function the values from different columns are in a Python `dict` with the column names as keys.

In this example we create a function to add the values in the `Age` and `Fare` columns

In [None]:
def sum_age_fare(struct:dict)->float:
    # We check if both values are floats
    if isinstance(struct["Age"],float) and isinstance(struct["Fare"],float):
        # If they are we add them
        return struct["Age"] + struct["Fare"]
    else:
        # If there is a null value return a null
        return None

We now apply the function to the `pl.Struct` expression created from the `Age` and `Fare` columns

In [None]:
(
    df
    .with_columns(
        age_fare_summed = pl.struct(pl.col("Age"),pl.col("Fare")).map_elements(sum_age_fare)
    )
    .select("Age","Fare","age_fare_summed")
    .head(6)
)

### Numpy ufuncs
ufuncs are common Numpy functions that work element-by-element such as `np.cos` or `np.exp`. We do not need to use `map_elements` if we are trying to apply Numpy ufuncs.

Instead we just pass the Polars expression to the Numpy function where we would normally pass a Numpy array. This should be just as fast as working with a Numpy array.

In this example we again square the values in the `Age` column

In [None]:
(
    df
    .with_columns(
        age_squared = np.power(pl.col("Age"),2)
    )
    .select("Age","age_squared")
    .head()
)

### Applying functions to a `Series`
With `map_elements` the function works one row at a time.

In other cases we want to apply user-defined functions that work on an entire `Series` at once. For this case we use `map_batches`.

In this example we want to normalise the `Age` values by subtracting the mean and dividing by the standard deviation. We cannot do this row-by-row with `map_elements` as we need the mean and standard deviation of the whole column first.

In [None]:
def normalize_column(s):
    # We use s as the variable to remind us the input here is a Series
    mean = s.mean()
    std = s.std()
    return (s - mean) / std

We now apply this function with `map_batches`

In [None]:
(
    df
    .with_columns(
        normalised_age = pl.col("Age").map_batches(normalize_column)
    )
    .select("Age","normalised_age")
    .head()
)

The output of the `map_batches` function must be a `Series` that is the same length as the other columns in the `DataFrame`. This means that we cannot, for example, return just the mean of a column using `map_batches`

#### Lazy mode
We can use `map_batches` in lazy mode. However, we cannot use `map_batches` in streaming mode. This is because the function would be applied over the different batches independently. In our example we would be calculating the mean and standard deviation for each batch in streaming rather than for the whole column

In [None]:
print(
    df
    .lazy()
    .with_columns(
        normalised_age = pl.col("Age").map_batches(normalize_column)
    )
    .select("Age","normalised_age")
    .explain(streaming=True)
)

We see here that the `map_batches` component comes after the streaming part of the query. If we know that our user-defined function is suitable for running batchwise in streaming mode we can set the `is_elementwise` argument to `True`

In [None]:
print(
    df
    .lazy()
    .with_columns(
        normalised_age = pl.col("Age").map_batches(normalize_column,is_elementwise=True)
    )
    .select("Age","normalised_age")
    .explain(streaming=True)
)

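Be sure you understand your function before doing this, however! Here `normalize_column` is not actually suitable: with `is_elementwise=True` the mean and standard deviation would be computed separately for each batch rather than for the whole column.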

### Functions in a groupby context
The behaviour of the `map_elements` and `map_batches` functions may be different from what you expect when applied in an `agg` after `group_by`. See the `agg` lecture in the `groupby` section for more details.

## Exercises
In the exercises you will develop your understanding of:
- applying elementwise functions with `map_elements`
- applying batched functions with `map_batches`

### Exercise 1
Get the length of the names in the `Name` column using a Python method

In [None]:
(
    pl.read_csv(csv_file)
    .select(
        "Name",
        <blank>
    )
    .head()
)

Make a function called `categorize_age` to categorize passengers based on their age with:
- "Child" with age less than 18
- "Adult" with age between 18 and 60 and
- "Senior" with age above 60

Apply this function to the `Age` column

Do this again using a more optimal `when.then` approach. Note that we must pass `pl.lit("Child")` so Polars knows it is a value rather than a column name

Create a function called `find_unaccompanied_adults` to identify passengers who:
- are under 18 years of age and
- are travelling with no parents or children (`Parch`) or siblings

For these passengers return `'Unaccompanied child'` while for other passengers return `'Other'`

Apply this function to the `DataFrame` and filter for the unaccompanied children

### Exercise 2
A Numpy array has an `argmax` method to return the index of the largest element in the array. Note that `argmax` is not a Numpy ufunc as it does not work elementwise.

Return a one-row `DataFrame` with the `argmax` of all the floating-point columns in the `DataFrame`.

In [None]:
(
    df
    .select(
        <blank>
    )
)

## Solutions
### Solution to exercise 1
Get the length of the names in the `Name` column using a Python method

In [None]:
(
    pl.read_csv(csv_file)
    .select(
        "Name",
        pl.col("Name").map_elements(lambda x: len(x)).alias("len_name")
    )
    .head()
)

Make a function to categorize passengers based on their age with:
- "Child" with age less than 18
- "Adult" with age between 18 and 60 and
- "Senior" with age above 60

In [None]:
def categorize_age(age):
    if age < 18:
        return 'Child'
    elif age <= 60:
        return 'Adult'
    else:
        return 'Senior'

Apply this function to the `Age` column

In [None]:
(
    df
    .select(
        "Age",
        pl.col("Age").map_elements(categorize_age).alias("age_category")
    )
    .head()
)

Do this again using a more optimal `when.then` approach. Note that we must pass `pl.lit("Child")` so Polars knows it is a value rather than a column name

In [None]:
(
    df
    .select(
        "Age",
        pl.when(pl.col("Age")<18)
        .then(pl.lit("Child"))
        .when(pl.col("Age")<=60)
        .then(pl.lit("Adult"))
        .otherwise(pl.lit("Senior"))
        .alias("age_category")
    )
    .head()
)

Create a function called `find_unaccompanied_adults` to identify passengers who:
- are under 18 years of age and
- are travelling with no parents or children (`Parch`) or siblings

For these passengers return `'Unaccompanied child'` while for other passengers return `'Other'`

In [None]:
def find_unaccompanied_adults(struct):
    if isinstance(struct["Age"],float) and struct["Age"] < 18 and struct["Parch"] == 0 and struct["SibSp"] == 0:
        return 'Unaccompanied child'
    else:
        return 'Other'

Apply this function to the `DataFrame` and filter for the unaccompanied children

In [None]:
(
    df
    .with_columns(
        pl.struct("Age","Parch","SibSp").map_elements(find_unaccompanied_adults).alias("unaccompanied_adults")
    )
    .filter(unaccompanied_adults = "Unaccompanied child")
    .head()
)

### Solution to exercise 2
A Numpy array has an `argmax` method to return the index of the largest element in the array. Note that `argmax` is not a Numpy ufunc as it does not work elementwise.

Return a one-row `DataFrame` with the `argmax` of the floating-point columns in the `DataFrame`.

In [None]:
(
    df
    .select(
        pl.col(pl.FLOAT_DTYPES).map_batches(lambda x:x.to_numpy().argmax())
    )
)