Skip to content

Commit

Permalink
added additional slice case
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Aug 18, 2023
1 parent 3ac4ef8 commit 28936d6
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion cola/ops/operator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,13 @@ def __getitem__(self, ids: Union[Tuple[int, ...], Tuple[slice, ...]])->Union[Arr
# print(type(ids[0]), type(ids[1]))
# check if first element is ellipsis
xnp = self.xnp
from cola.ops import Sliced
match ids:
case int(i):
ei = xnp.canonical(loc=i, shape=(self.shape[-1], ), dtype=self.dtype, device=self.device)
return (self.T @ ei)
case (slice() | xnp.ndarray() | np.ndarray()) as s_i:
return Sliced(A=self, slices=(s_i, slice(None)))
case b, int(j):
ej = xnp.canonical(loc=j, shape=(self.shape[-1], ), dtype=self.dtype, device=self.device)
return (self @ ej)[b]
Expand All @@ -210,7 +213,6 @@ def __getitem__(self, ids: Union[Tuple[int, ...], Tuple[slice, ...]])->Union[Arr
return (self.T @ ei)[b]
case (slice() | xnp.ndarray() | np.ndarray()) as s_i, \
(slice() | xnp.ndarray() | np.ndarray()) as s_j:
from cola.ops import Sliced
return Sliced(A=self, slices=(s_i, s_j))
case list(li), list(lj):
out = []
Expand Down

0 comments on commit 28936d6

Please sign in to comment.