In [None]:
from typing import Any, Hashable, Mapping, Tuple, Sequence


from tensorspecs import _TODONextShape, TensorSpec


import numpy
import numpy.lib.mixins


class TensorView(numpy.lib.mixins.NDArrayOperatorsMixin):
    """
    
    TODO doc

    .. doctest::

        >>> import numpy
        >>> tensor = TensorView(
        ...     numpy.empty((1, 2, 3)), 
        ...     dims=("h", "w", "c"),
        ...     indexes={
        ...         "c": ["r", "g", "b"]
        ...     },
        ... )
    
    """

    # TODO
    # spec: TensorSpec

    # TODO
    def __init__(
        self, 
        tensor: ..., 
        dims: Sequence[Hashable], 
        indexes: Mapping[Hashable, Sequence[Hashable]],
    ):
        self._tensor = tensor
        # TODO
        self._dims = dims
        # TODO FIXME perf: O(1) .index
        self._indexes = indexes

    def __repr__(self):
        return f"{TensorView.__qualname__}({self._tensor!r}, dims={self._dims!r}, indexes={self._indexes!r})"

    @property
    def shape(self):
        return _TODONextShape({
            dim: size 
            for dim, size in 
            zip(self._dims, numpy.shape(self._tensor))
        })
    
    def _get_raw_indexes(self, indexes: Mapping | Tuple | Any):
        match indexes:
            case dict():
                # TODO
                indexes_s = {
                    dim: indexes.get(dim, slice(None)) 
                    for dim in self._dims
                }
            case tuple():
                indexes_s = {
                    dim: index
                    for dim, index in zip(self._dims, indexes)
                }
            case single_index:
                indexes_s = {
                    next(iter(self._dims)): single_index
                }

        # TODO perf: see self._indexes !!!!!
        def _resolve_index(dim: ..., index: ...):
            if dim not in self._indexes:
                return index
            return self._indexes[dim].index(index)

        indexes_raw = []
        for dim, index in indexes_s.items():
            match index:
                case slice():
                    start = index.start
                    if start is not None:
                        start = _resolve_index(dim, start)
                    step = index.step
                    stop = index.stop
                    if stop is not None:
                        stop = _resolve_index(dim, stop)
                    indexes_raw.append(
                        slice(start, stop, step)
                    )
                case _ if isinstance(index, Sequence):
                    indexes_raw.append([
                        _resolve_index(dim, i)
                        for i in index
                    ])
                case _:
                    indexes_raw.append(
                        _resolve_index(dim, index)
                    )

        return tuple(indexes_raw)

    def __getitem__(self, indexes: Mapping | Tuple | Any):
        # TODO ret TensorView
        # TensorView(
        #     self._tensor[self._get_raw_indexes(indexes)],
        #     dims=self._dims,
        #     indexes=self._indexes,
        # )
        return self._tensor[self._get_raw_indexes(indexes)]

    # TODO
    def assign(self, indexes: ..., value: ..., copy: bool = False):
        new_tensor = numpy.array(self._tensor, copy=copy)
        new_tensor[self._get_raw_indexes(indexes)] = value
        return new_tensor
        # TODO
        # return TensorView(
        #     new_tensor, 
        #     dims=self._dims, 
        #     indexes=self._indexes,
        # )

    # TODO helper
    def iter(self, dim: Hashable):
        for index in range(self.shape.sizes[dim]):
            yield self.__getitem__({dim: index})

    def __array__(self, *args, **kwargs):
        return self._tensor.__array__(*args, **kwargs)
        
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # TODO
        # Forward ufuncs to the underlying ndarray without wrapping back.
        def unwrap(x):
            return x._tensor if isinstance(x, TensorView) else x

        # unwrap inputs and optional 'out='
        unwrapped_inputs = tuple(unwrap(x) for x in inputs)
        if "out" in kwargs:
            out = kwargs["out"]
            if isinstance(out, tuple):
                kwargs["out"] = tuple(unwrap(x) for x in out)
            else:
                kwargs["out"] = unwrap(out)

        return getattr(ufunc, method)(*unwrapped_inputs, **kwargs)

    def __dlpack__(self, *args, **kwargs):
        return self._tensor.__dlpack__(*args, **kwargs)
    
    def __dlpack_device__(self, *args, **kwargs):
        return self._tensor.__dlpack_device__(*args, **kwargs)



In [None]:

# Design sketch of the intended indexing / assignment API, kept in a throwaway
# string so that none of it is executed.
_ = """

tensor = TensorView(...)
tensor[{
    "time": slice(...),
    "step": 1,
    "c": (0, "link_0", 2, 3),
}]
tensor[:, "link1":"link2", ...]


tensor.assign({"link": ("panda_leftfinger", "panda_rightfinger")}, 0.)

"""

In [None]:
# Micro-benchmark: cost of constructing a TensorView around an existing array.
a = numpy.empty((10, 9, 8))
dims = ("h", "w", "c")
# Labels for the "c" axis only; "h" and "w" stay addressable by raw index.
indexes = {
    "c": ["r", "g", "b"]
}

# IPython line magic -- only runs inside IPython/Jupyter, not as plain Python.
%timeit -n 100 TensorView(a, dims=dims, indexes=indexes)

1.22 μs ± 90.9 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# import jax


# tensor = TensorView(
#     jax.numpy.empty((1, 2, 3)), 
#     dims=("h", "w", "c"),
#     indexes={
#         "c": ["r", "g", "b"]
#     },
# )

# tensor + 1
# tensor.assign(0, value=1)[0]
# tensor.assign({"c": "r"}, value=1)#[{"c": "r"}]

# next(tensor.iter(dim="w")).shape
# # (1, 3)