Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Polars materializer #2229

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/zenml/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from zenml.integrations.neural_prophet import NeuralProphetIntegration # noqa
from zenml.integrations.openai import OpenAIIntegration # noqa
from zenml.integrations.pillow import PillowIntegration # noqa
from zenml.integrations.polars import PolarsIntegration
from zenml.integrations.pycaret import PyCaretIntegration # noqa
from zenml.integrations.pytorch import PytorchIntegration # noqa
from zenml.integrations.pytorch_lightning import ( # noqa
Expand Down
1 change: 1 addition & 0 deletions src/zenml/integrations/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
OPEN_AI = "openai"
PILLOW = "pillow"
PLOTLY = "plotly"
POLARS = "polars"
PYCARET = "pycaret"
PYTORCH = "pytorch"
PYTORCH_L = "pytorch_lightning"
Expand Down
35 changes: 35 additions & 0 deletions src/zenml/integrations/polars/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Initialization of the Polars integration."""

from zenml.integrations.constants import POLARS
from zenml.integrations.integration import Integration


class PolarsIntegration(Integration):
"""Definition of Polars integration for ZenML."""

NAME = POLARS
REQUIREMENTS = [
"polars>=0.19.5",
"pyarrow>=12.0.0"
]

@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.polars import materializers # noqa


PolarsIntegration.check_installation()
18 changes: 18 additions & 0 deletions src/zenml/integrations/polars/materializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Initialization for the Polars materializers."""

from zenml.integrations.polars.materializers.dataframe_materializer import ( # noqa
PolarsMaterializer,
)
121 changes: 121 additions & 0 deletions src/zenml/integrations/polars/materializers/dataframe_materializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Polars materializer."""

import os
import tempfile
from typing import Any, ClassVar, Tuple, Type, Union

import polars as pl
import pyarrow as pa # type: ignore
import pyarrow.parquet as pq # type: ignore

from zenml.enums import ArtifactType
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.utils import io_utils


class PolarsMaterializer(BaseMaterializer):
"""Materializer to read/write Polars dataframes."""

ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (
pl.DataFrame,
pl.Series,
)
ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA

def load(self, data_type: Type[Any]) -> Any:
"""Reads and returns Polars data after copying it to temporary path.

Args:
data_type: The type of the data to read.

Returns:
A Polars data frame or series.
"""
# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory()

# Copy from artifact store to temporary directory
io_utils.copy_dir(self.uri, temp_dir.name)

# Load the data from the temporary directory
table = pq.read_table(
os.path.join(temp_dir.name, "dataframe.parquet").replace("\\", "/")
)

# If the data is of type pl.Series, convert it back to a pyarrow array
# instead of a table.
if (
table.schema.metadata
and b"zenml_is_pl_series" in table.schema.metadata
):
isinstance_bytes = table.schema.metadata[b"zenml_is_pl_series"]
isinstance_series = bool.from_bytes(isinstance_bytes, "big")
if isinstance_series:
table = table.column(0)

# Convert the table to a Polars data frame or series
data = pl.from_arrow(table)

# Cleanup and return
fileio.rmtree(temp_dir.name)

return data

def save(self, data: Union[pl.DataFrame, pl.Series]) -> None:
"""Writes Polars data to the artifact store.

Args:
data: The data to write.

Raises:
TypeError: If the data is not of type pl.DataFrame or pl.Series.
"""
# Data type check
if not isinstance(data, self.ASSOCIATED_TYPES):
raise TypeError(
f"Expected data of type {self.ASSOCIATED_TYPES}, "
f"got {type(data)}"
)

# Convert the data to an Apache Arrow Table
if isinstance(data, pl.DataFrame):
table = data.to_arrow()
else:
# Construct a PyArrow Table with schema from the individual pl.Series
# array if it is a single pl.Series.
array = data.to_arrow()
table = pa.Table.from_arrays([array], names=[data.name])

# Register whether data is of type pl.Series, so that the materializer read step can
# convert it back appropriately.
isinstance_bytes = isinstance(data, pl.Series).to_bytes(1, "big")
table = table.replace_schema_metadata(
{b"zenml_is_pl_series": isinstance_bytes}
)

# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory()

# Write the table to a Parquet file
path = os.path.join(temp_dir.name, "dataframe.parquet").replace(
"\\", "/"
)
pq.write_table(table, path) # Uses lz4 compression by default
io_utils.copy_dir(temp_dir.name, self.uri)

# Remove the temporary directory
fileio.rmtree(temp_dir.name)
13 changes: 13 additions & 0 deletions tests/integration/integrations/polars/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
13 changes: 13 additions & 0 deletions tests/integration/integrations/polars/materializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.


import polars
from polars import testing as polars_testing

from tests.unit.test_general import _test_materializer
from zenml.integrations.polars.materializers.dataframe_materializer import (
PolarsMaterializer,
)


def test_polars_materializer():
"""Test the polars materializer."""
dataframe = polars.DataFrame([0, 1, 2, 3], schema=["column_test"])
series = polars.Series([0, 1, 2, 3])

for type_, example in [
(polars.DataFrame, dataframe),
(polars.Series, series),
]:
result = _test_materializer(
step_output_type=type_,
materializer_class=PolarsMaterializer,
step_output=example,
assert_visualization_exists=False,
)

# Use different assertion given type, since Polars implements
# these differently.
if type_ == polars.DataFrame:
polars_testing.assert_frame_equal(example, result)
else:
polars_testing.assert_series_equal(example, result)
Loading