Skip to content

Commit

Permalink
Construct pylibcudf columns from objects supporting `__cuda_array_i…
Browse files Browse the repository at this point in the history
…nterface__` (#15615)

This PR allows zero copy construction of `pylibcudf` columns from device arrays via the `gpumemoryview` class. cc @mroeschke

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #15615
  • Loading branch information
brandon-b-miller committed May 2, 2024
1 parent 6882870 commit 4494991
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 32 deletions.
107 changes: 107 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/column.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ from .scalar cimport Scalar
from .types cimport DataType, type_id
from .utils cimport int_to_bitmask_ptr, int_to_void_ptr

import functools

import numpy as np


cdef class Column:
"""A container of nullable device data as a column of elements.
Expand Down Expand Up @@ -223,6 +227,51 @@ cdef class Column:
c_result = move(make_column_from_scalar(dereference(c_scalar), size))
return Column.from_libcudf(move(c_result))

@staticmethod
def from_cuda_array_interface_obj(object obj):
"""Create a Column from an object with a CUDA array interface.
Parameters
----------
obj : object
The object with the CUDA array interface to create a column from.
Returns
-------
Column
A Column containing the data from the CUDA array interface.
Notes
-----
Data is not copied when creating the column. The caller is
responsible for ensuring the data is not mutated unexpectedly while the
column is in use.
"""
data = gpumemoryview(obj)
iface = data.__cuda_array_interface__()
if iface.get('mask') is not None:
raise ValueError("mask not yet supported.")

typestr = iface['typestr'][1:]
if not is_c_contiguous(
iface['shape'],
iface['strides'],
np.dtype(typestr).itemsize
):
raise ValueError("Data must be C-contiguous")

data_type = _datatype_from_dtype_desc(typestr)
size = iface['shape'][0]
return Column(
data_type,
size,
data,
None,
0,
0,
[]
)

cpdef DataType type(self):
"""The type of data in the column."""
return self._data_type
Expand Down Expand Up @@ -296,3 +345,61 @@ cdef class ListColumnView:
cpdef offsets(self):
"""The offsets column of the underlying list column."""
return self._column.child(1)


@functools.cache
def _datatype_from_dtype_desc(desc):
mapping = {
'u1': type_id.UINT8,
'u2': type_id.UINT16,
'u4': type_id.UINT32,
'u8': type_id.UINT64,
'i1': type_id.INT8,
'i2': type_id.INT16,
'i4': type_id.INT32,
'i8': type_id.INT64,
'f4': type_id.FLOAT32,
'f8': type_id.FLOAT64,
'b1': type_id.BOOL8,
'M8[s]': type_id.TIMESTAMP_SECONDS,
'M8[ms]': type_id.TIMESTAMP_MILLISECONDS,
'M8[us]': type_id.TIMESTAMP_MICROSECONDS,
'M8[ns]': type_id.TIMESTAMP_NANOSECONDS,
'm8[s]': type_id.DURATION_SECONDS,
'm8[ms]': type_id.DURATION_MILLISECONDS,
'm8[us]': type_id.DURATION_MICROSECONDS,
'm8[ns]': type_id.DURATION_NANOSECONDS,
}
if desc not in mapping:
raise ValueError(f"Unsupported dtype: {desc}")
return DataType(mapping[desc])


def is_c_contiguous(
shape: Sequence[int], strides: Sequence[int], itemsize: int
) -> bool:
"""Determine if shape and strides are C-contiguous

Parameters
----------
shape : Sequence[int]
Number of elements in each dimension.
strides : Sequence[int]
The stride of each dimension in bytes.
itemsize : int
Size of an element in bytes.

Return
------
bool
The boolean answer.
"""

if any(dim == 0 for dim in shape):
return True
cumulative_stride = itemsize
for dim, stride in zip(reversed(shape), reversed(strides)):
if dim > 1 and stride != cumulative_stride:
return False
cumulative_stride *= dim
return True
36 changes: 4 additions & 32 deletions python/cudf/cudf/core/buffer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pickle
import weakref
from types import SimpleNamespace
from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple
from typing import Any, Dict, Literal, Mapping, Optional, Tuple

import numpy
from typing_extensions import Self
Expand Down Expand Up @@ -480,36 +480,6 @@ def __str__(self) -> str:
)


def is_c_contiguous(
shape: Sequence[int], strides: Sequence[int], itemsize: int
) -> bool:
"""Determine if shape and strides are C-contiguous
Parameters
----------
shape : Sequence[int]
Number of elements in each dimension.
strides : Sequence[int]
The stride of each dimension in bytes.
itemsize : int
Size of an element in bytes.
Return
------
bool
The boolean answer.
"""

if any(dim == 0 for dim in shape):
return True
cumulative_stride = itemsize
for dim, stride in zip(reversed(shape), reversed(strides)):
if dim > 1 and stride != cumulative_stride:
return False
cumulative_stride *= dim
return True


def get_ptr_and_size(array_interface: Mapping) -> Tuple[int, int]:
"""Retrieve the pointer and size from an array interface.
Expand All @@ -531,7 +501,9 @@ def get_ptr_and_size(array_interface: Mapping) -> Tuple[int, int]:
shape = array_interface["shape"] or (1,)
strides = array_interface["strides"]
itemsize = cudf.dtype(array_interface["typestr"]).itemsize
if strides is None or is_c_contiguous(shape, strides, itemsize):
if strides is None or cudf._lib.pylibcudf.column.is_c_contiguous(
shape, strides, itemsize
):
nelem = math.prod(shape)
ptr = array_interface["data"][0] or 0
return ptr, nelem * itemsize
Expand Down
51 changes: 51 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/test_column_from_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import pyarrow as pa
import pytest
from utils import assert_column_eq

import cudf
from cudf._lib import pylibcudf as plc

VALID_TYPES = [
pa.int8(),
pa.int16(),
pa.int32(),
pa.int64(),
pa.uint8(),
pa.uint16(),
pa.uint32(),
pa.uint64(),
pa.float32(),
pa.float64(),
pa.bool_(),
pa.timestamp("s"),
pa.timestamp("ms"),
pa.timestamp("us"),
pa.timestamp("ns"),
pa.duration("s"),
pa.duration("ms"),
pa.duration("us"),
pa.duration("ns"),
]


@pytest.fixture(params=VALID_TYPES, ids=repr)
def valid_type(request):
return request.param


@pytest.fixture
def valid_column(valid_type):
if valid_type == pa.bool_():
return pa.array([True, False, True], type=valid_type)
return pa.array([1, 2, 3], type=valid_type)


def test_from_cuda_array_interface(valid_column):
col = plc.column.Column.from_cuda_array_interface_obj(
cudf.Series(valid_column)
)
expect = valid_column

assert_column_eq(col, expect)

0 comments on commit 4494991

Please sign in to comment.