diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index f461513e8b3cf..8bfc31edc747d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2801,6 +2801,8 @@ cdef extern from "arrow/extension_type.h" namespace "arrow": cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType): c_string extension_name() shared_ptr[CDataType] storage_type() + int byte_width() + int bit_width() @staticmethod shared_ptr[CArray] WrapArray(shared_ptr[CDataType] ext_type, diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index fe38bf651baae..9863d96058947 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -251,14 +251,14 @@ def test_ext_type_repr(): assert repr(ty) == "IntegerType(DataType(int64))" -def test_ext_type__lifetime(): +def test_ext_type_lifetime(): ty = UuidType() wr = weakref.ref(ty) del ty assert wr() is None -def test_ext_type__storage_type(): +def test_ext_type_storage_type(): ty = UuidType() assert ty.storage_type == pa.binary(16) assert ty.__class__ is UuidType @@ -267,6 +267,32 @@ def test_ext_type__storage_type(): assert ty.__class__ is ParamExtType +def test_ext_type_byte_width(): + # Test for fixed-size binary types + ty = UuidType() + assert ty.byte_width == 16 + ty = ParamExtType(5) + assert ty.byte_width == 5 + + # Test for non fixed-size binary types + ty = LabelType() + with pytest.raises(ValueError, match="Non-fixed width type"): + _ = ty.byte_width + + +def test_ext_type_bit_width(): + # Test for fixed-size binary types + ty = UuidType() + assert ty.bit_width == 128 + ty = ParamExtType(5) + assert ty.bit_width == 40 + + # Test for non fixed-size binary types + ty = LabelType() + with pytest.raises(ValueError, match="Non-fixed width type"): + _ = ty.bit_width + + def test_ext_type_as_py(): ty = UuidType() expected = uuid4() diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 480f19c81dfb9..5113df36557f4 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1519,6 +1519,24 @@ cdef class BaseExtensionType(DataType): """ return pyarrow_wrap_data_type(self.ext_type.storage_type()) + @property + def byte_width(self): + """ + The byte width of the extension type. + """ + if self.ext_type.byte_width() == -1: + raise ValueError("Non-fixed width type") + return self.ext_type.byte_width() + + @property + def bit_width(self): + """ + The bit width of the extension type. + """ + if self.ext_type.bit_width() == -1: + raise ValueError("Non-fixed width type") + return self.ext_type.bit_width() + def wrap_array(self, storage): """ Wrap the given storage array as an extension array.