In [0]:

import operator
from contextlib import ExitStack

import numpy as np
from numba import njit
from numba.core import types
from numba.extending import (
    typeof_impl,
    type_callable,
    make_attribute_wrapper,
    overload_attribute,
    overload_method,
    lower_builtin,
    box,
    unbox,
    NativeValue,
)
from numba.extending import models, register_model
from numba.core import cgutils
from numba import TypingError, from_dtype


# 1. Define the custom NumPy array class
class SimpleArray:
    """Minimal Python container around a NumPy ndarray.

    ``shape``, ``dtype``, ``ndim`` and item access all forward to the wrapped
    array; ``sum`` reduces it with ``np.sum``.
    """

    def __init__(self, data):
        """
        Store ``data`` as a NumPy array.

        Args:
            data (np.ndarray or list): Array-like input, converted with
                ``np.asarray``.
        """
        self._data = np.asarray(data)

    @property
    def data(self):
        """The wrapped ``np.ndarray``."""
        return self._data

    @property
    def shape(self):
        """Shape tuple of the wrapped array."""
        return self._data.shape

    @property
    def dtype(self):
        """NumPy dtype of the wrapped array."""
        return self._data.dtype

    @property
    def ndim(self):
        """Number of dimensions of the wrapped array."""
        return self._data.ndim

    def __repr__(self):
        as_list = self._data.tolist()
        return f"SimpleArray({as_list})"

    def __getitem__(self, index):
        # Indexing semantics are exactly those of the wrapped ndarray.
        item = self._data[index]
        return item

    def __setitem__(self, index, value):
        target = self._data
        target[index] = value

    def sum(self):
        """Sum of all elements, computed with ``np.sum``."""
        total = np.sum(self._data)
        return total


# 2. Define the Numba type for the custom array
class SimpleArrayType(types.Type):
    """Numba-side type of a ``SimpleArray``, parameterised by element type and rank.

    The type name is built from both parameters (e.g. ``SimpleArray[float64, 1D]``),
    so instances with the same ``dtype`` and ``ndim`` share a name and
    instances that differ in either get distinct names.
    """

    def __init__(self, dtype, ndim):
        """
        Args:
            dtype (numba.core.types.Type): Numba type of the array's elements.
            ndim (int): Number of dimensions of the array.
        """
        self.dtype, self.ndim = dtype, ndim
        super().__init__(name=f"SimpleArray[{dtype}, {ndim}D]")


# 3. Register the type with Numba
@typeof_impl.register(SimpleArray)
def typeof_simple_array(val, c):
    """
    Map a Python ``SimpleArray`` instance to its ``SimpleArrayType``.

    Args:
        val (SimpleArray): Instance whose type is being inferred.
        c: Typeof context supplied by Numba; not consulted here.

    Returns:
        SimpleArrayType: Type built from the NumPy dtype (converted with
        ``from_dtype``) and the number of dimensions of ``val``.
    """
    element_type = from_dtype(val.dtype)
    rank = val.ndim
    return SimpleArrayType(element_type, rank)


# 4. Create a Numba model for the type (how it's represented in memory)
@register_model(SimpleArrayType)
class SimpleArrayModel(models.StructModel):
    """Native struct layout of a SimpleArray.

    Members:
        data:  the wrapped Numba array value (a full array struct, not just a
               data pointer).
        shape: tuple of ``ndim`` intp extents.
        ndim:  number of dimensions as an int32.

    ``data`` is declared with layout ``"A"`` (any) rather than ``"C"``: the
    Python ``SimpleArray`` wraps whatever ``np.asarray`` returns (strided
    views included), and unboxing converts ``_data`` as an ``"A"``-layout
    array. Claiming C-contiguity here would make jitted code index strided
    inputs incorrectly. All array layouts share the same LLVM representation,
    so this only changes typing, not memory layout.
    """

    def __init__(self, dmm, fe_type):
        members = [
            ("data", types.Array(fe_type.dtype, fe_type.ndim, "A")),
            ("shape", types.UniTuple(types.intp, fe_type.ndim)),
            ("ndim", types.int32),
        ]
        models.StructModel.__init__(self, dmm, fe_type, members)


# 5. Implement attribute access (.data, .shape, .ndim)
# Each struct member is exposed in jitted code as a read-only attribute of the
# same name; registration order matches the member list.
for _member in ("data", "shape", "ndim"):
    make_attribute_wrapper(SimpleArrayType, _member, _member)
del _member


# 6. Implement lowering (converting Python code to LLVM IR)
@type_callable(SimpleArray)
def type_simple_array(context):
    """Typing for ``SimpleArray(array)`` called inside jitted code.

    Without a typing registration Numba cannot resolve the call, so the
    lowering below would never be selected.
    """

    def typer(data):
        if isinstance(data, types.Array):
            return SimpleArrayType(data.dtype, data.ndim)
        return None  # not an array: let other SimpleArray templates try

    return typer


@lower_builtin(SimpleArray, types.Array)  # Constructor
def impl_simple_array(context, builder, sig, args):
    """Build a native SimpleArray struct from a Numba array argument.

    The whole array value (meminfo, shape, strides, data pointer) is stored in
    the ``data`` member. Storing only the raw data pointer is an invalid store
    into an Array-typed member and would not keep the buffer alive.

    Raises:
        TypingError: if the argument's layout cannot be represented by the
            model's ``data`` member (a non-"A" member with a different layout).

    Returns:
        A new reference to the SimpleArray struct.
    """
    from numba.core.imputils import impl_ret_borrowed

    typ = sig.return_type
    array_type = sig.args[0]
    array = args[0]  # borrowed reference to the numba array
    data_type = context.data_model_manager[typ].get_member_fe_type("data")
    if array_type.layout != data_type.layout and data_type.layout != "A":
        raise TypingError(
            f"cannot store a {array_type.layout!r}-layout array in a "
            f"SimpleArray whose data member is {data_type.layout!r}-layout"
        )

    array_struct = context.make_array(array_type)(context, builder, value=array)
    array_obj = cgutils.create_struct_proxy(typ)(context, builder)
    array_obj.data = context.cast(builder, array, array_type, data_type)
    array_obj.shape = array_struct.shape
    array_obj.ndim = context.get_constant(types.int32, array_type.ndim)
    # The argument is borrowed but now also referenced by the returned struct,
    # so take a reference (impl_ret_borrowed increfs when NRT is enabled).
    return impl_ret_borrowed(context, builder, typ, array_obj._getvalue())


# 6b. Implement a direct constructor from a list (for testing/convenience)
@type_callable(SimpleArray)
def type_simple_array_list(context):
    """Typing for ``SimpleArray(list_of_numbers)`` inside jitted code (1-D result)."""

    def typer(data):
        if isinstance(data, types.List) and isinstance(data.dtype, types.Number):
            return SimpleArrayType(data.dtype, 1)
        return None  # not a numeric list: let other SimpleArray templates try

    return typer


@lower_builtin(SimpleArray, types.List)
def impl_simple_array_list(context, builder, sig, args):
    """Build a 1-D SimpleArray by copying a Numba list into a new C array.

    The list struct does not expose (size, data) as its first two fields, so
    the elements are copied with ``np.array`` compiled via
    ``context.compile_internal`` instead of by reading list internals.

    Returns:
        A new reference to the SimpleArray struct. It owns the freshly
        allocated array.
    """
    from numba.core.typing import signature

    def list_to_array(lst):
        return np.array(lst)

    typ = sig.return_type
    list_type = sig.args[0]
    arr_type = types.Array(list_type.dtype, 1, "C")  # 1D, C-contiguous array
    data_type = context.data_model_manager[typ].get_member_fe_type("data")

    # compile_internal returns a new reference to the copied array.
    arr = context.compile_internal(
        builder, list_to_array, signature(arr_type, list_type), args
    )
    arr_struct = context.make_array(arr_type)(context, builder, value=arr)

    array_obj = cgutils.create_struct_proxy(typ)(context, builder)
    # A 'C' array is representable by a 'C' or 'A' data member.
    array_obj.data = context.cast(builder, arr, arr_type, data_type)
    array_obj.shape = arr_struct.shape
    array_obj.ndim = context.get_constant(types.int32, 1)
    # Ownership of `arr` moves into the struct; no extra incref needed.
    return array_obj._getvalue()


# Unboxing
@unbox(SimpleArrayType)
def unbox_simple_array(typ, obj, c):
    """
    Convert a SimpleArray Python object to a Numba native structure.

    Reads ``obj._data``, unboxes it as an ``'A'``-layout (any layout) Numba
    array, and stores the *whole* native array value in the ``data`` member,
    together with its shape and the static ``ndim``.

    Returns:
        NativeValue whose ``is_error`` flag is set if reading ``_data`` or
        unboxing it fails.
    """
    is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
    array_obj = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    # The Python class wraps arbitrary ndarrays (strided views included), so
    # no particular layout is assumed.
    array_type = types.Array(typ.dtype, typ.ndim, "A")

    with ExitStack() as stack:
        data_obj = c.pyapi.object_getattr_string(obj, "_data")
        with cgutils.early_exit_if_null(c.builder, stack, data_obj):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        native_array = c.unbox(array_type, data_obj)
        c.pyapi.decref(data_obj)  # Decref after unboxing
        with cgutils.early_exit_if(c.builder, stack, native_array.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        array_struct = c.context.make_array(array_type)(
            c.context, c.builder, value=native_array.value
        )
        # Store the full array struct (meminfo, shape, strides, data). The
        # member is declared as an Array type, so storing only the raw data
        # pointer would be an invalid store.
        array_obj.data = native_array.value
        array_obj.shape = array_struct.shape
        array_obj.ndim = c.context.get_constant(types.int32, typ.ndim)

    return NativeValue(array_obj._getvalue(), is_error=c.builder.load(is_error_ptr))


"""
@box(SimpleArrayType)
def box_simple_array(typ, val, c):
    array_obj = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)

    array_type = types.Array(typ.dtype, typ.ndim, 'A')
    array_struct = c.context.make_array(array_type)(c.context, c.builder)
    array_struct.data = array_obj.data
    array_struct.shape = array_obj.shape
    #Crucially, we need meminfo for boxing to work
    array_struct.meminfo = c.context.nrt.meminfo_data_from_datapointer(c.builder, array_obj.data, array_type)

    np_array = c.box(array_type, array_struct._getvalue())

    with ExitStack() as stack:
        simple_array_cls = c.pyapi.unserialize(c.pyapi.serialize_object(SimpleArray))
        with cgutils.early_exit_if_null(c.builder, stack, simple_array_cls):
            c.pyapi.decref(np_array)  # Decref on error
            return None

        new_simple_array = c.pyapi.call_function_objargs(simple_array_cls, (np_array,))
        c.pyapi.decref(np_array)
        c.pyapi.decref(simple_array_cls)
        with cgutils.early_exit_if_null(c.builder, stack, new_simple_array):
            return None
    return new_simple_array
"""


@lower_builtin("getitem", SimpleArrayType, types.Integer)
def impl_getitem(context, builder, sig, args):
    array_type, index_type = sig.args
    array_obj = cgutils.create_struct_proxy(array_type)(context, builder, value=args[0])
    zero = context.get_constant(types.intp, 0)
    is_negative = builder.icmp_signed("<", args[1], zero)
    with builder.if_then(is_negative, likely=False):
        context.call_conv.return_user_exc(builder, IndexError, ("index out of range",))

    index_value = builder.sext(args[1], types.intp)  # Extend to pointer size if needed.
    ptr = cgutils.get_item_pointer(
        builder, array_type.dtype, array_obj.data, [index_value], wraparound=False
    )
    return context.unpack_value(builder, array_type.dtype, ptr)


@lower_builtin("setitem", SimpleArrayType, types.Integer, types.Any)
def impl_setitem(context, builder, sig, args):
    array_type, index_type, value_type = sig.args
    array_obj = cgutils.create_struct_proxy(array_type)(context, builder, value=args[0])
    zero = context.get_constant(types.intp, 0)
    is_negative = builder.icmp_signed("<", args[1], zero)
    with builder.if_then(is_negative, likely=False):
        context.call_conv.return_user_exc(builder, IndexError, ("index out of range",))
    index_value = builder.sext(args[1], types.intp)  # Extend index
    ptr = cgutils.get_item_pointer(
        builder, array_type.dtype, array_obj.data, [index_value], wraparound=False
    )
    val = context.cast(builder, args[2], value_type, array_type.dtype)  # Cast to dtype
    builder.store(val, ptr)
    return context.get_dummy_value()


# Tell numba that .sum() is available as a method
@overload_method(SimpleArrayType, "sum")
def get_sum(array):
    """Provide ``SimpleArray.sum()`` in jitted code.

    ``.sum()`` is a method call, so it needs ``overload_method``:
    ``overload_attribute`` with a zero-argument impl cannot type, and calling
    ``array.sum()`` inside it would resolve back to itself. The impl delegates
    to Numba's ndarray ``sum`` on the wrapped data, mirroring the Python
    method's ``np.sum(self._data)``.
    """

    def impl(array):
        return array.data.sum()

    return impl


@lower_builtin(SimpleArray.sum, SimpleArrayType)
def impl_array_sum(context, builder, sig, args):
    """
    Lowers the .sum() method of SimpleArray.

    Sums every element of the wrapped array by compiling Numba's own
    ``ndarray.sum`` against the struct's ``data`` member. This handles any
    rank and layout of that member. The previous hand-written loop ran
    ``shape[0] * (ndim - 1)`` iterations (zero for 1-D) and rebound the
    accumulator across loop blocks, which produces invalid IR.

    NOTE(review): no typing visible here resolves calls to the key
    ``SimpleArray.sum``; confirm this lowering is actually reachable.
    """
    from numba.core.typing import signature

    def sum_data(data):
        return data.sum()

    array_type = sig.args[0]
    array_obj = cgutils.create_struct_proxy(array_type)(context, builder, value=args[0])
    data_type = context.data_model_manager[array_type].get_member_fe_type("data")
    # The result type is whatever the selecting typing template declared.
    return context.compile_internal(
        builder, sum_data, signature(sig.return_type, data_type), [array_obj.data]
    )


# 7. Example Usage (within a Numba-jitted function)
@njit
def sum_simple_array(arr):
    """Return ``arr[0] + arr[1] + ... + arr[shape[0] - 1]`` via explicit indexing."""
    n = arr.shape[0]
    acc = 0
    for idx in range(n):
        acc += arr[idx]
    return acc


@njit
def modify_array(arr):
    """Overwrite the first element of ``arr`` with 100 and return ``arr``."""
    first_index = 0
    arr[first_index] = 100
    return arr


@njit
def use_sum_method(arr):
    """Return ``arr.sum()`` evaluated inside jitted code."""
    result = arr.sum()
    return result

In [0]:
# Create a SimpleArray from a list and inspect it on the Python side
# (pure-Python usage; nothing here goes through Numba).
data_list = [1.0, 2.0, 3.0, 4.0]
arr_from_list = SimpleArray(data_list)
print(
    f"From list: {arr_from_list}",
    arr_from_list.dtype,
    type(arr_from_list.data),
    sep="\n",
)


@njit
def get_data(arr):
    """Return the wrapped array (the ``data`` attribute) from jitted code."""
    wrapped = arr.data
    return wrapped

import dis

# Print the CPython bytecode of the pure-Python SimpleArray class.
dis.dis(x=SimpleArray)