Skip to content

Commit

Permalink
Use unpack_index from our code
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed May 21, 2019
1 parent 289957d commit 9ce9d03
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
11 changes: 5 additions & 6 deletions anndata/base.py
Expand Up @@ -18,7 +18,6 @@
from pandas.api.types import is_string_dtype, is_categorical
from scipy import sparse
from scipy.sparse import issparse
from scipy.sparse.sputils import IndexMixin
from natsort import natsorted

# try importing zarr
Expand All @@ -43,7 +42,7 @@ def __rep__():
from .layers import AnnDataLayers

from . import utils
from .utils import Index, get_n_items_idx
from .utils import Index, get_n_items_idx, unpack_index
from .logging import anndata_logger as logger
from .compat import PathLike

Expand Down Expand Up @@ -444,7 +443,7 @@ class DataFrameView(_ViewMixin, pd.DataFrame):
_metadata = ['_view_args']


class Raw(IndexMixin):
class Raw:
def __init__(
self,
adata: Optional['AnnData'] = None,
Expand Down Expand Up @@ -531,7 +530,7 @@ def _normalize_indices(self, packed_index):
packed_index = packed_index[0], packed_index[1].values
if isinstance(packed_index[0], pd.Series):
packed_index = packed_index[0].values, packed_index[1]
obs, var = super()._unpack_index(packed_index)
obs, var = unpack_index(packed_index)
obs = _normalize_index(obs, self._adata.obs_names)
var = _normalize_index(var, self.var_names)
return obs, var
Expand All @@ -551,7 +550,7 @@ def __init__(self, n_dims):
super().__init__(msg)


class AnnData(IndexMixin, metaclass=utils.DeprecationMixinMeta):
class AnnData(metaclass=utils.DeprecationMixinMeta):
"""An annotated data matrix.
:class:`~anndata.AnnData` stores a data matrix :attr:`X` together with annotations
Expand Down Expand Up @@ -1298,7 +1297,7 @@ def _normalize_indices(self, index: Optional[Index]):
# Needs to be refactored once we support a tuple of two arbitrary index types
if any(isinstance(i, np.ndarray) and i.dtype == bool for i in index):
return index
obs, var = super()._unpack_index(index)
obs, var = unpack_index(index)
obs = _normalize_index(obs, self.obs_names)
var = _normalize_index(var, self.var_names)
return obs, var
Expand Down
8 changes: 4 additions & 4 deletions anndata/h5py/h5sparse.py
Expand Up @@ -7,8 +7,8 @@
import h5py
import numpy as np
import scipy.sparse as ss
from scipy.sparse.sputils import IndexMixin

from ..utils import unpack_index
from ..compat import PathLike

from .utils import _chunked_rows
Expand Down Expand Up @@ -236,7 +236,7 @@ def _zero_many(self, i, j):
_cs_matrix._zero_many = _zero_many


class SparseDataset(IndexMixin):
class SparseDataset:
"""Analogous to :class:`h5py.Dataset <h5py:Dataset>`, but for sparse matrices.
"""

Expand All @@ -255,7 +255,7 @@ def format_str(self):

def __getitem__(self, index):
if index == (): index = slice(None)
row, col = self._unpack_index(index)
row, col = unpack_index(index)
format_class = get_format_class(self.format_str)
mock_matrix = format_class(self.shape, dtype=self.dtype)
mock_matrix.data = self.h5py_group['data']
Expand All @@ -265,7 +265,7 @@ def __getitem__(self, index):

def __setitem__(self, index, value):
if index == (): index = slice(None)
row, col = self._unpack_index(index)
row, col = unpack_index(index)
format_class = get_format_class(self.format_str)
mock_matrix = format_class(self.shape, dtype=self.dtype)
mock_matrix.data = self.h5py_group['data']
Expand Down
23 changes: 21 additions & 2 deletions anndata/utils.py
@@ -1,9 +1,10 @@
import warnings
from functools import wraps
from typing import Mapping, Any, Sequence, Union, Sized, Optional
from typing import Mapping, Any, Sequence, Union, Tuple

import pandas as pd
import numpy as np
from scipy.sparse import spmatrix

from .logging import get_logger
if False:
Expand Down Expand Up @@ -132,7 +133,7 @@ def is_deprecated(attr):
]


Index = Union[slice, int, np.int64, np.ndarray, Sized]
Index = Union[slice, int, np.int64, np.ndarray, spmatrix]


def get_n_items_idx(idx: Index, l: int):
Expand All @@ -147,3 +148,21 @@ def get_n_items_idx(idx: Index, l: int):
return 1
else:
return len(idx)


def unpack_index(index: Union[Index, Tuple[Index, Index]]) -> Tuple[Index, Index]:
# handle indexing with boolean matrices
if (
isinstance(index, (spmatrix, np.ndarray))
and index.ndim == 2
and index.dtype.kind == 'b'
): return index.nonzero()

if not isinstance(index, tuple):
return index, slice(None)
elif len(index) == 2:
return index
elif len(index) == 1:
return index[0], slice(None)
else:
raise IndexError('invalid number of indices')

0 comments on commit 9ce9d03

Please sign in to comment.