[IR] Implement __buffer__ for tensors #2241


Status: Draft. Wants to merge 22 commits into base branch main.
Changes from all commits
55 changes: 52 additions & 3 deletions onnxscript/ir/_core.py
@@ -12,7 +12,6 @@

 from __future__ import annotations
 
-import abc
 import contextlib
 import dataclasses
 import heapq
@@ -40,7 +39,7 @@

 import ml_dtypes
 import numpy as np
-from typing_extensions import TypeIs
+from typing_extensions import Buffer, TypeIs
 
 import onnxscript
 from onnxscript.ir import (
@@ -95,7 +94,7 @@
     return hasattr(obj, "__dlpack__")
 
 
-class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
+class TensorBase(Buffer, _protocols.TensorProtocol, _display.PrettyPrintable):
     """Convenience Shared methods for classes implementing TensorProtocol."""
 
     __slots__ = ()
@@ -111,6 +110,13 @@
"""
return f"{self.__class__.__name__}<{self._printable_type_shape()}>"

def __buffer__(self, flags: int, /) -> memoryview:
"""Return a memoryview of the tensor.

This is used to support the buffer protocol.
"""
return self.tobytes().__buffer__(flags)

@property
def size(self) -> int:
"""The number of elements in the tensor."""
@@ -408,6 +414,29 @@
     def __repr__(self) -> str:
         return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
 
+    def __buffer__(self, flags: int, /) -> memoryview:
+        """Return a memoryview of the tensor.
+
+        This is used to support the buffer protocol.
+        """
+        if self.dtype in {
+            _enums.DataType.INT4,
+            _enums.DataType.UINT4,
+            _enums.DataType.FLOAT4E2M1,
+        }:
+            # Packing is required. So we call tobytes() directly
+            return self.tobytes().__buffer__(flags)
+
+        # Otherwise get the memoryview from the numpy array
+        array = self.numpy()
+        if not array.data.c_contiguous:
+            array = np.ascontiguousarray(array)
+        assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
+        if not _IS_LITTLE_ENDIAN:
+            # Need to copy because we are returning the underlying data directly
+            array = array.view(array.dtype.newbyteorder("<")).copy()
+        return array.__buffer__(flags)
+

Check failure: Code scanning / lintrunner (MYPY/attr-defined)
"ndarray[Any, Any]" has no attribute "__buffer__". To disable, use # type: ignore[attr-defined]

     @property
     def dtype(self) -> _enums.DataType:
         """The data type of the tensor. Immutable."""
@@ -657,6 +686,19 @@
         assert self._array is not None
         return self._array.__array__(dtype)
 
+    def __buffer__(self, flags: int, /) -> memoryview:
+        """Return a memoryview of the tensor.
+
+        This is used to support the buffer protocol.
+        """
+        self._check_validity()
+        if self.raw is None:
+            self._load()
+        assert self.raw is not None
+        offset = self._offset or 0
+        length = self._length or self.nbytes
+        return memoryview(self.raw)[offset : offset + length]
+
     def __dlpack__(self, *, stream: Any = None) -> Any:
         raise NotImplementedError(
             "ExternalTensor does not support DLPack because it uses memory mapping. "
@@ -953,6 +995,13 @@
     def __repr__(self) -> str:
         return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"
 
+    def __buffer__(self, flags: int, /) -> memoryview:
+        """Return a memoryview of the tensor.
+
+        This is used to support the buffer protocol.
+        """
+        return self._evaluate().__buffer__(flags)
+
     @property
     def raw(self) -> Callable[[], _protocols.TensorProtocol]:
         return self._func
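With __buffer__ implemented across the tensor classes, a tensor can be handed directly to any consumer of buffer objects on Python 3.12+ (PEP 688; typing_extensions.Buffer supplies the base class for earlier type-checking targets). A minimal usage sketch, assuming an ir.Tensor backed by a NumPy array:

import numpy as np
from onnxscript import ir

tensor = ir.Tensor(np.arange(6, dtype=np.float32), name="weights")

view = memoryview(tensor)  # dispatches to Tensor.__buffer__ (Python 3.12+)
assert view.nbytes == tensor.nbytes
assert bytes(view) == tensor.tobytes()

# The view can be written out without an intermediate bytes copy
with open("weights.bin", "wb") as f:
    f.write(view)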
4 changes: 4 additions & 0 deletions onnxscript/ir/_protocols.py
@@ -133,6 +133,10 @@ def __array__(self, dtype: Any = None) -> np.ndarray:
"""Return the tensor as a numpy array, compatible with np.array."""
...

def __buffer__(self, flags: int, /) -> memoryview:
"""Return a view of the tensor data."""
...

def __dlpack__(self, *, stream: Any = ...) -> Any:
"""Return PyCapsule."""
...
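Declaring __buffer__ on TensorProtocol also lets type checkers accept tensors wherever a buffer is expected. A small illustrative sketch (byte_length is a hypothetical helper, not part of the IR):

from typing_extensions import Buffer

def byte_length(data: Buffer) -> int:
    # Accepts bytes, NumPy arrays, and now IR tensors alike
    with memoryview(data) as view:
        return view.nbytes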
8 changes: 4 additions & 4 deletions onnxscript/ir/external_data.py
@@ -173,14 +173,14 @@ def _write_external_data(
     for tensor, tensor_info in zip(tensors, external_data_infos, strict=True):
         current_offset = tensor_info.offset
         assert tensor is not None
-        raw_data = tensor.tobytes()
-        if isinstance(tensor, _core.ExternalTensor):
-            tensor.release()
         # Pad file to required offset if needed
         file_size = data_file.tell()
         if current_offset > file_size:
             data_file.write(b"\0" * (current_offset - file_size))
-        data_file.write(raw_data)
+        with memoryview(tensor) as view:
+            data_file.write(view)
+        if isinstance(tensor, _core.ExternalTensor):
+            tensor.release()
 
 
 def _create_external_tensor(
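The rewritten loop borrows the tensor's buffer instead of materializing tensor.tobytes(), so a memory-mapped ExternalTensor is streamed to the file without holding a full copy of its payload in memory; release() is accordingly moved after the write, while the mapping backing the view is still valid. A sketch of the difference (function names hypothetical):

def write_with_copy(f, tensor):
    f.write(tensor.tobytes())  # allocates the whole payload as a bytes object

def write_with_view(f, tensor):
    with memoryview(tensor) as view:  # borrows the underlying storage
        f.write(view)  # streams straight from the (possibly mmapped) buffer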
31 changes: 27 additions & 4 deletions onnxscript/ir/tensor_adapters.py
@@ -79,25 +79,48 @@
     def numpy(self) -> npt.NDArray:
         import torch
 
-        self.raw: torch.Tensor
+        # Calling .contiguous() is usually less costly than calling it on numpy arrays
+        # so we do it first for users assuming a contiguous array is needed for most usages
+        torch_tensor: torch.Tensor = self.raw
+        if not torch_tensor.is_contiguous():
+            torch_tensor = torch_tensor.contiguous()
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
+            return torch_tensor.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
-            return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
+            return torch_tensor.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
 
-        return self.raw.numpy(force=True)
+        return torch_tensor.numpy(force=True)
 
     def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
         del copy  # Unused, but needed for the signature
         if dtype is None:
             return self.numpy()
         return self.numpy().__array__(dtype)
 
+    def __buffer__(self, flags: int, /) -> memoryview:
+        """Return a memoryview of the tensor.
+
+        This is used to support the buffer protocol.
+        """
+        if self.dtype in {
+            ir.DataType.INT4,
+            ir.DataType.UINT4,
+            ir.DataType.FLOAT4E2M1,
+        }:
+            # Packing is required. So we call tobytes() directly
+            return self.tobytes().__buffer__(flags)
+
+        # Otherwise get the memoryview from the numpy array
+        array = self.numpy()
+        assert array.data.c_contiguous, "Bug: The array should be contiguous"
+        assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
+        return array.__buffer__(flags)
+
Check failure: Code scanning / lintrunner (MYPY/attr-defined)
"ndarray[Any, dtype[Any]]" has no attribute "__buffer__". To disable, use # type: ignore[attr-defined]

     def tobytes(self) -> bytes:
         # Implement tobytes to support native PyTorch types so we can use types like bfloat16
         # Reading from memory directly is also more efficient because
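For the PyTorch adapter the same protocol applies: a torch-backed tensor can be consumed without a detour through bytes. A minimal sketch, assuming the TorchTensor constructor takes the torch tensor and an optional name:

import torch
from onnxscript.ir.tensor_adapters import TorchTensor

t = TorchTensor(torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16), name="t")

view = memoryview(t)  # calls TorchTensor.__buffer__ on Python 3.12+
assert view.nbytes == 3 * 2  # bfloat16 is 2 bytes per element
assert bytes(view) == t.tobytes()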