Skip to content

Commit

Permalink
feat(python): Support pytorch Tensor and Dataset export with new `t…
Browse files Browse the repository at this point in the history
…o_torch` DataFrame/Series method
  • Loading branch information
alexander-beedie committed Apr 29, 2024
1 parent 2e28176 commit e7315e4
Show file tree
Hide file tree
Showing 12 changed files with 692 additions and 36 deletions.
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/dataframe/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ Export DataFrame data to other formats:
DataFrame.to_numpy
DataFrame.to_pandas
DataFrame.to_struct
DataFrame.to_torch
11 changes: 11 additions & 0 deletions py-polars/docs/source/reference/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,14 @@ Connect to pyarrow datasets.
:toctree: api/

scan_pyarrow_dataset

PyTorch Tensors / Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: api/

DataFrame.to_torch

.. currentmodule:: polars.utils

~torch.PolarsDataset
177 changes: 176 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@
N_INFER_DEFAULT,
Boolean,
Float64,
Int32,
Int64,
Object,
String,
UInt16,
UInt32,
UInt64,
)
from polars.dependencies import (
_HVPLOT_AVAILABLE,
Expand Down Expand Up @@ -108,7 +113,7 @@
)
from polars.selectors import _expand_selector_dicts, _expand_selectors
from polars.slice import PolarsSlice
from polars.type_aliases import DbWriteMode
from polars.type_aliases import DbWriteMode, TorchExportType

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import dtype_str_repr as _dtype_str_repr
Expand All @@ -121,6 +126,7 @@
from typing import Literal

import deltalake
import torch
from hvplot.plotting.core import hvPlotTabularPolars
from xlsxwriter import Workbook

Expand Down Expand Up @@ -164,6 +170,7 @@
UniqueKeepStrategy,
UnstackDirection,
)
from polars.io.torch import PolarsDataset

if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec, TypeAlias
Expand Down Expand Up @@ -1612,6 +1619,174 @@ def raise_on_copy(msg: str) -> None:

return out

@overload
def to_torch(
self,
return_type: Literal["tensor"] = ...,
*,
label: str | Expr | Sequence[str | Expr] | None = ...,
dtype: PolarsDataType | None = ...,
) -> torch.Tensor: ...

@overload
def to_torch(
self,
return_type: Literal["dataset"],
*,
label: str | Expr | Sequence[str | Expr] | None = ...,
dtype: PolarsDataType | None = ...,
) -> PolarsDataset: ...

@overload
def to_torch(
self,
return_type: Literal["dict"],
*,
label: str | Expr | Sequence[str | Expr] | None = ...,
dtype: PolarsDataType | None = ...,
) -> dict[str, torch.Tensor]: ...

def to_torch(
self,
return_type: TorchExportType = "tensor",
*,
label: str | Expr | Sequence[str | Expr] | None = None,
dtype: PolarsDataType | None = None,
) -> torch.Tensor | dict[str, torch.Tensor] | PolarsDataset:
"""
Convert DataFrame to a 2D PyTorch tensor, Dataset, or dict of Tensors.
Parameters
----------
return_type : {"tensor", "dataset", "dict"}
Set return type; a 2D PyTorch tensor, PolarsDataset (which is a drop-in
compatible frame-specialized TensorDataset), or dict of Tensors.
label
One or more column names or expressions that label the feature data; when
`return_type` is "dataset", the PolarsDataset returns `(features, label)`
tensor tuples for each row. Otherwise, it returns `(features,)` tensor
tuples where the feature contains all the row data. This parameter is a
no-op for the other return-types.
dtype
Unify the dtype of all returned tensors; this casts any frame Series
that are not of the required dtype before converting to tensor. This
includes the label column *unless* the label is an expression (such
as `pl.col("label_column").cast(pl.Int16)`).
Notes
-----
The convenience :class:`PolarsDataset` class returned by `return_type="dataset"`
implements flexible row retrieval as `(features, label)` tensors, for easy
integration with PyTorch `DataLoader` objects, and is drop-in compatible
with `TensorDataset` (from which it inherits).
Examples
--------
>>> df = pl.DataFrame(
... {
... "lbl": [0, 1, 2, 3],
... "feat1": [1, 0, 0, 1],
... "feat2": [1.5, -0.5, 0.0, -2.25],
... }
... )
Standard return type (Tensor), with f32 supertype:
>>> df.to_torch(dtype=pl.Float32)
tensor([[ 0.0000, 1.0000, 1.5000],
[ 1.0000, 0.0000, -0.5000],
[ 2.0000, 0.0000, 0.0000],
[ 3.0000, 1.0000, -2.2500]])
As a dictionary of individual Tensors:
>>> df.to_torch("dict")
{'lbl': tensor([0, 1, 2, 3]),
'feat1': tensor([1, 0, 0, 1]),
'feat2': tensor([ 1.5000, -0.5000, 0.0000, -2.2500], dtype=torch.float64)}
As a PolarsDataset, with f64 supertype:
>>> ds = df.to_torch("dataset", dtype=pl.Float64)
>>> ds[3]
(tensor([ 3.0000, 1.0000, -2.2500], dtype=torch.float64),)
>>> ds[:2]
(tensor([[ 0.0000, 1.0000, 1.5000],
[ 1.0000, 0.0000, -0.5000]], dtype=torch.float64),)
>>> ds[[0, 3]]
(tensor([[ 0.0000, 1.0000, 1.5000],
[ 3.0000, 1.0000, -2.2500]], dtype=torch.float64),)
As a convenience the PolarsDataset can opt-in to half-precision data:
>>> list(ds.half())
[(tensor([0.0000, 1.0000, 1.5000], dtype=torch.float16),),
(tensor([ 1.0000, 0.0000, -0.5000], dtype=torch.float16),),
(tensor([2., 0., 0.], dtype=torch.float16),),
(tensor([ 3.0000, 1.0000, -2.2500], dtype=torch.float16),)]
Pass PolarsDataset to a DataLoader, designating the labels column:
>>> from torch.utils.data import DataLoader
>>> ds = df.to_torch("dataset", label="lbl")
>>> dl = DataLoader(ds, batch_size=2)
>>> batches = list(dl)
>>> batches[0]
[tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000]], dtype=torch.float64), tensor([0, 1])]
Note that labels can be given as expressions, allowing them to have a
dtype independent of the feature columns (multi-column labels are also
supported).
>>> ds = df.to_torch(
... "dataset",
... dtype=pl.Float32,
... label=pl.col("lbl").cast(pl.Int16),
... )
>>> ds[:2]
(tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000]]), tensor([0, 1], dtype=torch.int16))
Easily integrate with (for example) scikit-learn and other datasets:
>>> from sklearn.datasets import fetch_california_housing # doctest: +SKIP
>>> housing = fetch_california_housing() # doctest: +SKIP
>>> df = pl.DataFrame(
... data=housing.data,
... schema=housing.feature_names,
... ).with_columns(
... Target=housing.target,
... ) # doctest: +SKIP
>>> train = df.to_torch("dataset", label="Target") # doctest: +SKIP
>>> loader = DataLoader(
... train,
... shuffle=True,
... batch_size=64,
... ) # doctest: +SKIP
"""
torch = import_optional("torch")

if dtype in (UInt16, UInt32, UInt64):
msg = f"PyTorch does not support u16, u32, or u64 dtypes; given {dtype}"
raise ValueError(msg)
else:
to_dtype = dtype or {UInt16: Int32, UInt32: Int64, UInt64: Int64}
frame = self.cast(to_dtype) # type: ignore[arg-type]

if return_type == "tensor":
return torch.from_numpy(frame.to_numpy(writable=True, use_pyarrow=False))
elif return_type == "dict":
return {srs.name: srs.to_torch() for srs in frame}
elif return_type == "dataset":
from polars.io.torch import PolarsDataset

return PolarsDataset(frame, label=label)
else:
valid_torch_types = ", ".join(get_args(TorchExportType))
msg = f"invalid `return_type`: {return_type!r}\nExpected one of: {valid_torch_types}"
raise ValueError(msg)

def to_pandas(
self,
*,
Expand Down

0 comments on commit e7315e4

Please sign in to comment.