Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 51 additions & 12 deletions tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import mmap
import os
import stat

import sys
import tempfile
Expand Down Expand Up @@ -266,7 +267,7 @@ def from_tensor(
result = result.view(shape)
result = cls(result)
result._handler = handler
result._filename = filename
result.filename = filename
result.index = None
result.parent_shape = shape
if copy_data:
Expand Down Expand Up @@ -321,7 +322,7 @@ def from_storage(

tensor = cls(tensor)
if filename is not None:
tensor._filename = filename
tensor.filename = filename
elif handler is not None:
tensor._handler = handler
if index is not None:
Expand All @@ -339,6 +340,17 @@ def filename(self):
raise RuntimeError("The MemoryMappedTensor has no file associated.")
return filename

@filename.setter
def filename(self, value):
if value is None and self._filename is None:
return
value = str(Path(value).absolute())
if self._filename is not None and value != self._filename:
raise RuntimeError(
"the MemoryMappedTensor has already a filename associated."
)
self._filename = value

@classmethod
def empty_like(cls, input, *, filename=None):
# noqa: D417
Expand Down Expand Up @@ -596,7 +608,7 @@ def empty(cls, *args, **kwargs):
*offsets_strides,
)
result = cls(result)
result._filename = filename
result.filename = filename
return result
return result

Expand Down Expand Up @@ -712,6 +724,8 @@ def from_filename(cls, filename, dtype, shape, index=None):
tensor.

"""
writable = _is_writable(filename)

if isinstance(shape, torch.Tensor):
func_offset_stride = getattr(
torch, "_nested_compute_contiguous_strides_offsets", None
Expand All @@ -724,24 +738,40 @@ def from_filename(cls, filename, dtype, shape, index=None):
"nested tensors. Please upgrade to a more recent "
"version."
)
tensor = torch.from_file(
str(filename), shared=True, dtype=dtype, size=shape.prod(-1).sum().int()
)
if writable:
tensor = torch.from_file(
str(filename),
shared=True,
dtype=dtype,
size=shape.prod(-1).sum().int(),
)
else:
with open(str(filename), "rb") as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
tensor = torch.frombuffer(mm, dtype=dtype)
# mm.close()
tensor = torch._nested_view_from_buffer(
tensor,
shape,
*offsets_strides,
)
else:
shape = torch.Size(shape)
tensor = torch.from_file(
str(filename), shared=True, dtype=dtype, size=shape.numel()
).view(shape)
# whether the file already existed
if writable:
tensor = torch.from_file(
str(filename), shared=True, dtype=dtype, size=shape.numel()
)
else:
with open(str(filename), "rb") as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
tensor = torch.frombuffer(mm, dtype=dtype)
tensor = tensor.view(shape)

if index is not None:
tensor = tensor[index]
out = cls(tensor)
out._filename = filename
out.filename = filename
out._handler = None
out.index = index
out.parent_shape = shape
Expand Down Expand Up @@ -787,7 +817,7 @@ def from_handler(cls, handler, dtype, shape, index=None):
if index is not None:
out = out[index]
out = cls(out)
out._filename = None
out.filename = None
out._handler = handler
out.index = index
out.parent_shape = shape
Expand Down Expand Up @@ -880,7 +910,7 @@ def _index_wrap(self, tensor, item, check=False):
return tensor
tensor = MemoryMappedTensor(tensor)
tensor._handler = getattr(self, "_handler", None)
tensor._filename = getattr(self, "_filename", None)
tensor.filename = getattr(self, "_filename", None)
tensor.index = item
tensor.parent_shape = getattr(self, "parent_shape", None)
return tensor
Expand Down Expand Up @@ -1038,3 +1068,12 @@ def _unbind(tensor, dim):
@implements_for_memmap(torch.chunk)
def _chunk(input, chunks, dim=0):
return input.chunk(chunks, dim=dim)


def _is_writable(file_path):
file_path = str(file_path)
if os.path.exists(file_path):
st = os.stat(file_path)
return bool(st.st_mode & (stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH))
# Assume that the file can be written in the directory
return True
92 changes: 84 additions & 8 deletions test/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
import argparse
import gc
import os
import stat
from contextlib import nullcontext
from pathlib import Path

Expand All @@ -12,7 +14,7 @@
from _utils_internal import get_available_devices
from tensordict import TensorDict

from tensordict.memmap import MemoryMappedTensor
from tensordict.memmap import _is_writable, MemoryMappedTensor
from torch import multiprocessing as mp

TIMEOUT = 100
Expand Down Expand Up @@ -157,7 +159,7 @@ def test_zeros(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 0).all()

@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
Expand Down Expand Up @@ -191,7 +193,7 @@ def test_ones(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 1).all()

@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
Expand Down Expand Up @@ -225,7 +227,7 @@ def test_empty(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())

@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg):
Expand Down Expand Up @@ -258,7 +260,7 @@ def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 2).all()

def test_zeros_like(self, shape, dtype, device, tmp_path, from_path):
Expand All @@ -272,7 +274,7 @@ def test_zeros_like(self, shape, dtype, device, tmp_path, from_path):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 0).all()

def test_ones_like(self, shape, dtype, device, tmp_path, from_path):
Expand All @@ -286,7 +288,7 @@ def test_ones_like(self, shape, dtype, device, tmp_path, from_path):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 1).all()

def test_full_like(self, shape, dtype, device, tmp_path, from_path):
Expand All @@ -300,7 +302,7 @@ def test_full_like(self, shape, dtype, device, tmp_path, from_path):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 2).all()

def test_from_filename(self, shape, dtype, device, tmp_path, from_path):
Expand Down Expand Up @@ -715,6 +717,80 @@ def test_save_td_with_nested(self, tmpdir):
assert (td[i, j] == tdsave[i, j]).all()


class TestReadWrite:
def test_read_only(self, tmpdir):
tmpdir = Path(tmpdir)
file_path = tmpdir / "elt.mmap"
mmap = MemoryMappedTensor.from_filename(
filename=file_path, shape=[2, 3], dtype=torch.float64
)
mmap.copy_(torch.arange(6).view(2, 3))

file_path = str(file_path.absolute())

assert _is_writable(file_path)
# Modify the permissions field to set the desired permissions
new_permissions = stat.S_IREAD # | stat.S_IWRITE | stat.S_IEXEC

# change permission
os.chmod(file_path, new_permissions)

# Get the current file status
assert not _is_writable(file_path)

del mmap

# load file
mmap = MemoryMappedTensor.from_filename(
filename=file_path, shape=[2, 3], dtype=torch.float64
)
assert (mmap.reshape(-1) == torch.arange(6)).all()

@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
def test_read_only_nested(self, tmpdir):
tmpdir = Path(tmpdir)
file_path = tmpdir / "elt.mmap"
data = MemoryMappedTensor.from_tensor(torch.arange(26), filename=file_path)
mmap = MemoryMappedTensor.from_storage(
data.untyped_storage(),
filename=file_path,
shape=torch.tensor([[2, 3], [4, 5]]),
dtype=data.dtype,
)

file_path = str(file_path.absolute())
assert _is_writable(file_path)

# Modify the permissions field to set the desired permissions
new_permissions = stat.S_IREAD # | stat.S_IWRITE | stat.S_IEXEC

# change permission
os.chmod(file_path, new_permissions)

# Get the current file status
assert not _is_writable(file_path)

# load file
mmap1 = MemoryMappedTensor.from_filename(
filename=file_path, shape=torch.tensor([[2, 3], [4, 5]]), dtype=data.dtype
)
assert (mmap1[0].view(-1) == torch.arange(6)).all()
assert (mmap1[1].view(-1) == torch.arange(6, 26)).all()
# test filename
assert mmap1.filename == mmap.filename
assert mmap1.filename == data.filename
assert mmap1.filename == data.untyped_storage().filename
with pytest.raises(AssertionError):
assert mmap1.untyped_storage().filename == data.untyped_storage().filename

os.chmod(str(file_path), 0o444)
data.fill_(0)
os.chmod(str(file_path), 0o444)

assert (mmap1[0].view(-1) == 0).all()
assert (mmap1[1].view(-1) == 0).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)