Add non-in-place version of Batch.to_torch (#1117)
Closes: #1116

### API Extensions

- `Batch` received a new method: `to_torch_`. #1117

### Breaking Changes

- The method `to_torch` in `data.utils.batch.Batch` is no longer in-place.
  Instead, the new method `to_torch_` performs the conversion in-place. #1117
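
A minimal migration sketch (illustrative only; assumes the usual `tianshou.data` import path):

```python
import numpy as np
import torch

from tianshou.data import Batch

batch = Batch(a=np.zeros((3, 4)))

# to_torch now returns a converted copy and leaves its argument untouched.
new_batch = Batch.to_torch(batch, dtype=torch.float32, device="cpu")
assert isinstance(batch.a, np.ndarray)
assert isinstance(new_batch.a, torch.Tensor)

# The former in-place behavior lives under the trailing-underscore name.
batch.to_torch_(dtype=torch.float32, device="cpu")
assert isinstance(batch.a, torch.Tensor)
```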
dantp-ai committed Apr 17, 2024
1 parent ca4f74f · commit 6935a11
Showing 5 changed files with 52 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docs/01_tutorials/03_batch.rst
@@ -475,12 +475,12 @@ Miscellaneous Notes
.. raw:: html

<details>
-    <summary>Batch.to_torch and Batch.to_numpy</summary>
+    <summary>Batch.to_torch_ and Batch.to_numpy_</summary>

::

>>> data = Batch(a=np.zeros((3, 4)))
-    >>> data.to_torch(dtype=torch.float32, device='cpu')
+    >>> data.to_torch_(dtype=torch.float32, device='cpu')
>>> print(data.a)
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
2 changes: 1 addition & 1 deletion docs/02_notebooks/L1_Batch.ipynb
@@ -333,7 +333,7 @@
"source": [
"batch_cat.to_numpy_()\n",
"print(batch_cat)\n",
"batch_cat.to_torch()\n",
"batch_cat.to_torch_()\n",
"print(batch_cat)"
]
},
30 changes: 27 additions & 3 deletions test/base/test_batch.py
@@ -379,7 +379,7 @@ def test_batch_over_batch_to_torch() -> None:
        b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
    )
    batch.b.__dict__["e"] = 1  # bypass the check
-    batch.to_torch()
+    batch.to_torch_()
    assert isinstance(batch.a, torch.Tensor)
    assert isinstance(batch.b.c, torch.Tensor)
    assert isinstance(batch.b.d, torch.Tensor)
@@ -391,7 +391,7 @@ def test_batch_over_batch_to_torch() -> None:
        assert batch.b.e.dtype == torch.int32
    else:
        assert batch.b.e.dtype == torch.int64
-    batch.to_torch(dtype=torch.float32)
+    batch.to_torch_(dtype=torch.float32)
    assert batch.a.dtype == torch.float32
    assert batch.b.c.dtype == torch.float32
    assert batch.b.d.dtype == torch.float32
@@ -477,7 +477,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    a_mem_addr_orig = batch.a.__array_interface__["data"][0]
    c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
-    batch.to_torch()
+    batch.to_torch_()
    batch.to_numpy_()
    a_mem_addr_new = batch.a.__array_interface__["data"][0]
    c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
@@ -727,6 +727,30 @@ def test_to_numpy_() -> None:
    assert isinstance(batch.c.d, np.ndarray)


+class TestToTorch:
+    """Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()`."""
+
+    @staticmethod
+    def test_to_torch() -> None:
+        batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
+        new_batch: Batch = Batch.to_torch(batch)
+        assert id(batch) != id(new_batch)
+        assert isinstance(batch.b, np.ndarray)
+        assert isinstance(batch.c.d, np.ndarray)
+
+        assert isinstance(new_batch.b, torch.Tensor)
+        assert isinstance(new_batch.c.d, torch.Tensor)
+
+    @staticmethod
+    def test_to_torch_() -> None:
+        batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
+        id_batch = id(batch)
+        batch.to_torch_()
+        assert id_batch == id(batch)
+        assert isinstance(batch.b, torch.Tensor)
+        assert isinstance(batch.c.d, torch.Tensor)


if __name__ == "__main__":
    test_batch()
    test_batch_over_batch()
22 changes: 21 additions & 1 deletion tianshou/data/batch.py
@@ -281,7 +281,16 @@ def to_numpy_(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
...

@staticmethod
def to_torch(
batch: TBatch,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
) -> TBatch:
"""Change all numpy.ndarray to torch.Tensor and return a new Batch."""
...

def to_torch_(
self,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
@@ -641,7 +650,18 @@ def to_numpy_(self) -> None:
            elif isinstance(obj, Batch):
                obj.to_numpy_()

+    @staticmethod
+    def to_torch(
+        batch: TBatch,
+        dtype: torch.dtype | None = None,
+        device: str | int | torch.device = "cpu",
+    ) -> TBatch:
+        new_batch = Batch(batch, copy=True)
+        new_batch.to_torch_(dtype=dtype, device=device)
+
+        return new_batch  # type: ignore[return-value]
+
-    def to_torch(
+    def to_torch_(
        self,
        dtype: torch.dtype | None = None,
        device: str | int | torch.device = "cpu",
@@ -662,7 +682,7 @@ def to_torch(
                    else:
                        self.__dict__[batch_key] = obj.to(device)
            elif isinstance(obj, Batch):
-                obj.to_torch(dtype, device)
+                obj.to_torch_(dtype, device)
            else:
                # ndarray or scalar
                if not isinstance(obj, np.ndarray):
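The non-in-place variant is a thin wrapper: it copy-constructs a new `Batch` and delegates to `to_torch_`, so the two methods cannot drift apart behaviorally. A sketch of that equivalence (assuming the usual `tianshou.data` import path):

```python
import numpy as np
import torch

from tianshou.data import Batch

b = Batch(a=np.ones((2, 3)))

# The static method converts a copy, never the source.
via_static = Batch.to_torch(b, dtype=torch.float32)

# Manual spelling of what the wrapper does internally.
copied = Batch(b, copy=True)
copied.to_torch_(dtype=torch.float32)

assert torch.equal(via_static.a, copied.a)
assert isinstance(b.a, np.ndarray)  # the source batch is never mutated
```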
2 changes: 1 addition & 1 deletion tianshou/data/utils/converter.py
@@ -57,7 +57,7 @@ def to_torch(
        return to_torch(np.asanyarray(x), dtype, device)
    if isinstance(x, dict | Batch):
        x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
-        x.to_torch(dtype, device)
+        x.to_torch_(dtype, device)
        return x
    if isinstance(x, list | tuple):
        return to_torch(_parse_value(x), dtype, device)
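The generic converter keeps the same non-mutating contract: dict inputs are wrapped in a fresh `Batch` (and `Batch` inputs deep-copied) before the renamed in-place method runs on the copy. A usage sketch (the import path is taken from the file patched above; treat it as an assumption):

```python
import numpy as np
import torch

from tianshou.data.utils.converter import to_torch

d = {"obs": np.zeros((2, 3)), "act": np.array([0, 1])}
converted = to_torch(d, dtype=torch.float32, device="cpu")

assert isinstance(converted.obs, torch.Tensor)  # a Batch of tensors comes back
assert isinstance(d["obs"], np.ndarray)         # the input dict is untouched
```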
