Skip to content

Commit

Permalink
feat(python): Clarify to_torch "features" and "label" parameter beh…
Browse files Browse the repository at this point in the history
…aviour when return type is not "dataset" (#16218)
  • Loading branch information
alexander-beedie committed May 14, 2024
1 parent 674a048 commit 3b21311
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
21 changes: 13 additions & 8 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,15 +1669,16 @@ def to_torch(
Set return type; a 2D PyTorch tensor, PolarsDataset (a frame-specialized
TensorDataset), or dict of Tensors.
label
One or more column names or expressions that label the feature data; when
`return_type` is "dataset", the PolarsDataset returns `(features, label)`
tensor tuples for each row. Otherwise, it returns `(features,)` tensor
tuples where the feature contains all the row data. This parameter is a
no-op for the other return-types.
One or more column names, expressions, or selectors that label the feature
data; when `return_type` is "dataset", the PolarsDataset will return
`(features, label)` tensor tuples for each row. Otherwise, it returns
`(features,)` tensor tuples where the feature contains all the row data;
note that setting this parameter with any other result type will raise an
informative error.
features
One or more column names or expressions that contain the feature data; if
omitted, all columns that are not designated as part of the label are used.
This parameter is a no-op for return-types other than "dataset".
One or more column names, expressions, or selectors that contain the feature
data; if omitted, all columns that are not designated as part of the label
are used. This parameter is a no-op for return-types other than "dataset".
dtype
Unify the dtype of all returned tensors; this casts any frame Series
that are not of the required dtype before converting to tensor. This
Expand Down Expand Up @@ -1770,6 +1771,10 @@ def to_torch(
... batch_size=64,
... ) # doctest: +SKIP
"""
if return_type != "dataset" and (label is not None or features is not None):
msg = "the `label` and `features` parameters can only be set when `return_type='dataset'`"
raise ValueError(msg)

torch = import_optional("torch")

if dtype in (UInt16, UInt32, UInt64):
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/dataframe/test_to_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,15 @@ def test_misc_errors(self, df: pl.DataFrame) -> None:
match="tensors used as indices must be long, int",
):
_res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)]

with pytest.raises(
ValueError,
match="`label` and `features` parameters .* when `return_type='dataset'`",
):
_res3 = df.to_torch(label="stroopwafel")

with pytest.raises(
ValueError,
match="`label` and `features` parameters .* when `return_type='dataset'`",
):
_res4 = df.to_torch("dict", features=cs.float())

0 comments on commit 3b21311

Please sign in to comment.