Skip to content

Commit

Permalink
Simplified layer code, improved tests
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Aug 20, 2018
1 parent adbff2b commit 77c5ebf
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 78 deletions.
15 changes: 9 additions & 6 deletions anndata/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Main class and helper functions.
"""
import os, sys
import warnings
import os
import sys
import logging as logg
from enum import Enum
from collections import OrderedDict
from functools import reduce
from pathlib import Path
from textwrap import indent, dedent
from typing import Union, Optional, Any, Iterable, Mapping, Sequence, Sized, Tuple, List
from copy import deepcopy
Expand All @@ -26,6 +25,7 @@
from zarr.core import Array as ZarrArray
except ImportError:
class ZarrArray:
    """Placeholder for `zarr.core.Array`, used in the `ImportError` fallback.

    It only needs to exist as a class and to identify itself as a mock when
    printed.
    """

    @staticmethod
    def __rep__():
        # Kept under its original name so any existing caller of
        # `ZarrArray.__rep__()` keeps working.
        return 'mock zarr.core.Array'

    def __repr__(self):
        # `__rep__` is not a special method: without `__repr__`, `repr()`
        # falls back to the default `<... object at 0x...>` form.
        return self.__rep__()

Expand Down Expand Up @@ -68,6 +68,9 @@ def classes(cls):
return tuple(c.value for c in cls.__members__.values())


Index = Union[slice, int, np.int64, np.ndarray, Sized]


class BoundRecArr(np.recarray):
"""A `np.recarray` to which fields can be added using `.['key']`.
Expand Down Expand Up @@ -652,7 +655,7 @@ def __init__(
filename: Optional[PathLike] = None,
filemode: Optional[str] = None,
asview: bool = False,
*, oidx=None, vidx=None):
*, oidx: Index = None, vidx: Index = None):
if asview:
if not isinstance(X, AnnData):
raise ValueError('`X` has to be an AnnData object.')
Expand All @@ -665,8 +668,8 @@ def __init__(
dtype=dtype, shape=shape,
filename=filename, filemode=filemode)

def _init_as_view(self, adata_ref, oidx, vidx):
def get_n_items_idx(idx):
def _init_as_view(self, adata_ref: 'AnnData', oidx: Index, vidx: Index):
def get_n_items_idx(idx: Index):
if isinstance(idx, np.ndarray) and idx.dtype == bool:
return idx.sum()
else:
Expand Down
37 changes: 20 additions & 17 deletions anndata/layers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
#thats just start
from typing import Mapping, Optional, Union

import numpy as np
from collections import OrderedDict
from scipy.sparse import issparse

class AnnDataLayers():
if False:
from .base import AnnData, Index # noqa

def __init__(self, adata, layers=None, dtype='float32', adata_ref=None, oidx=None, vidx=None):

class AnnDataLayers:
def __init__(
self,
adata: 'AnnData',
layers: Optional[Mapping[str, np.ndarray]] = None,
dtype: Union[str, np.dtype] = 'float32',
adata_ref: 'AnnData' = None,
oidx: 'Index' = None,
vidx: 'Index' = None,
):
self._adata = adata
self._adata_ref = adata_ref
self._oidx = oidx
Expand Down Expand Up @@ -63,29 +74,21 @@ def __delattr__(self, key):
def keys(self):
if self.isview:
return self._adata_ref.layers.keys()
else:
else: # TODO @Koncopd: Why wrap this in list() and not the above?
return list(self._layers.keys())

def items(self, copy=True):
if self.isview:
if copy:
return [(k, v[self._oidx, self._vidx].copy()) for (k, v) in self._adata_ref.layers.items()]
else:
return [(k, v[self._oidx, self._vidx]) for (k, v) in self._adata_ref.layers.items()]
pairs = [(k, v[self._oidx, self._vidx]) for (k, v) in self._adata_ref.layers.items()]
else:
if copy:
return [(k, v.copy()) for (k, v) in self._layers.items()]
else:
return self._layers.items()
pairs = self._layers.items()
return [(k, v.copy()) for (k, v) in pairs] if copy else pairs

def as_dict(self, copy=True):
return {k:v for (k, v) in self.items(copy)}
return dict(self.items(copy))

def __len__(self):
if self.isview:
return len(self._adata_ref.layers)
else:
return len(self._layers)
return len(self._adata_ref.layers if self.isview else self._layers)

@property
def isview(self):
Expand Down
111 changes: 56 additions & 55 deletions anndata/tests/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,59 +39,60 @@
# -------------------------------------------------------------------------------


def test_readwrite_h5ad():
for typ in [np.array, csr_matrix]:
X = typ(X_list)
adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
assert pd.api.types.is_string_dtype(adata.obs['oanno1'])
adata.raw = adata
adata.write('./test.h5ad')
adata = ad.read('./test.h5ad')
assert pd.api.types.is_categorical(adata.obs['oanno1'])
assert pd.api.types.is_string_dtype(adata.obs['oanno2'])
assert adata.obs.index.tolist() == ['name1', 'name2', 'name3']
assert adata.obs['oanno1'].cat.categories.tolist() == ['cat1', 'cat2']
assert pd.api.types.is_categorical(adata.raw.var['vanno2'])


def test_readwrite_dynamic():
for typ in [np.array, csr_matrix]:
X = typ(X_list)
adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
adata.filename = './test.h5ad' # change to backed mode
adata.write()
adata = ad.read('./test.h5ad')
assert pd.api.types.is_categorical(adata.obs['oanno1'])
assert pd.api.types.is_string_dtype(adata.obs['oanno2'])
assert adata.obs.index.tolist() == ['name1', 'name2', 'name3']
assert adata.obs['oanno1'].cat.categories.tolist() == ['cat1', 'cat2']


def test_readwrite_zarr():
for typ in [np.array, csr_matrix]:
X = typ(X_list)
adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
assert pd.api.types.is_string_dtype(adata.obs['oanno1'])
adata.write_zarr('./test_zarr_dir', chunks=True)
adata = ad.read_zarr('./test_zarr_dir')
assert pd.api.types.is_categorical(adata.obs['oanno1'])
assert pd.api.types.is_string_dtype(adata.obs['oanno2'])
assert adata.obs.index.tolist() == ['name1', 'name2', 'name3']
assert adata.obs['oanno1'].cat.categories.tolist() == ['cat1', 'cat2']
@pytest.mark.parametrize('typ', [np.array, csr_matrix])
def test_readwrite_h5ad(typ):
    """Round-trip an AnnData object (with `.raw` set) through an h5ad file.

    `typ` builds `X` from `X_list`, so the test runs once per parametrized
    constructor. Before writing, `obs['oanno1']` is a string column; after
    reading it back it must be categorical, while `obs['oanno2']` stays a
    string column.
    """
    X = typ(X_list)
    adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
    assert pd.api.types.is_string_dtype(adata.obs['oanno1'])
    adata.raw = adata
    adata.write('./test.h5ad')
    adata = ad.read('./test.h5ad')
    # `pd.api.types.is_categorical` was removed in pandas 2.0; checking the
    # column dtype directly works on old and new pandas alike.
    categorical = pd.api.types.CategoricalDtype
    assert isinstance(adata.obs['oanno1'].dtype, categorical)
    assert pd.api.types.is_string_dtype(adata.obs['oanno2'])
    assert adata.obs.index.tolist() == ['name1', 'name2', 'name3']
    assert adata.obs['oanno1'].cat.categories.tolist() == ['cat1', 'cat2']
    assert isinstance(adata.raw.var['vanno2'].dtype, categorical)


@pytest.mark.parametrize('typ', [np.array, csr_matrix])
def test_readwrite_dynamic(typ):
    """Write an AnnData object via its `filename` attribute and read it back.

    `typ` builds `X` from `X_list`, so the test runs once per parametrized
    constructor. After the round trip `obs['oanno1']` must be categorical and
    `obs['oanno2']` must still be a string column.
    """
    X = typ(X_list)
    adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
    adata.filename = './test.h5ad'  # change to backed mode
    adata.write()
    adata = ad.read('./test.h5ad')
    # `pd.api.types.is_categorical` was removed in pandas 2.0; checking the
    # column dtype directly works on old and new pandas alike.
    assert isinstance(adata.obs['oanno1'].dtype, pd.api.types.CategoricalDtype)
    assert pd.api.types.is_string_dtype(adata.obs['oanno2'])
    assert adata.obs.index.tolist() == ['name1', 'name2', 'name3']
    assert adata.obs['oanno1'].cat.categories.tolist() == ['cat1', 'cat2']


@pytest.mark.skipif(not find_spec('zarr'), reason='Zarr is not installed')
@pytest.mark.parametrize('typ', [np.array, csr_matrix])
def test_readwrite_zarr(typ):
    """Round-trip an AnnData object through a zarr directory store.

    Skipped when the `zarr` package cannot be found. `typ` builds `X` from
    `X_list`, so the test runs once per parametrized constructor. Before
    writing, `obs['oanno1']` is a string column; after reading it back it must
    be categorical, while `obs['oanno2']` stays a string column.
    """
    X = typ(X_list)
    adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
    assert pd.api.types.is_string_dtype(adata.obs['oanno1'])
    adata.write_zarr('./test_zarr_dir', chunks=True)
    adata = ad.read_zarr('./test_zarr_dir')
    # `pd.api.types.is_categorical` was removed in pandas 2.0; checking the
    # column dtype directly works on old and new pandas alike.
    assert isinstance(adata.obs['oanno1'].dtype, pd.api.types.CategoricalDtype)
    assert pd.api.types.is_string_dtype(adata.obs['oanno2'])
    assert adata.obs.index.tolist() == ['name1', 'name2', 'name3']
    assert adata.obs['oanno1'].cat.categories.tolist() == ['cat1', 'cat2']


@pytest.mark.skipif(not find_spec('loompy'), reason='Loompy is not installed (expected on Python 3.5)')
def test_readwrite_loom():
for i, typ in enumerate([np.array, csr_matrix]):
X = typ(X_list)
adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
adata.write_loom('./test.loom')
adata = ad.read_loom('./test.loom', sparse=(i == 1))
if isinstance(X, np.ndarray):
assert np.allclose(adata.X, X)
else:
# TODO: this should not be necessary
assert np.allclose(adata.X.toarray(), X.toarray())
@pytest.mark.parametrize('typ', [np.array, csr_matrix])
def test_readwrite_loom(typ):
    """Write an AnnData object to a loom file and compare `X` after reading.

    `typ` builds the matrix from `X_list`; the file is read back with
    `sparse=True` exactly when `typ` is `csr_matrix`.
    """
    loom_path = './test.loom'
    matrix = typ(X_list)
    written = ad.AnnData(matrix, obs=obs_dict, var=var_dict, uns=uns_dict)
    written.write_loom(loom_path)
    restored = ad.read_loom(loom_path, sparse=typ is csr_matrix)
    if not isinstance(matrix, np.ndarray):
        # TODO: this should not be necessary
        assert np.allclose(restored.X.toarray(), matrix.toarray())
        return
    assert np.allclose(restored.X, matrix)


def test_read_csv():
Expand All @@ -116,8 +117,8 @@ def test_read_tsv_iter():
assert adata.X.tolist() == X_list


def test_write_csv():
for typ in [np.array, csr_matrix]:
X = typ(X_list)
adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
adata.write_csvs('./test_csv_dir', skip_data=False)
@pytest.mark.parametrize('typ', [np.array, csr_matrix])
def test_write_csv(typ):
    """Smoke test: `write_csvs` with `skip_data=False` must not raise.

    `typ` builds `X` from `X_list`, so the test runs once per parametrized
    constructor.
    """
    ad.AnnData(
        typ(X_list),
        obs=obs_dict,
        var=var_dict,
        uns=uns_dict,
    ).write_csvs('./test_csv_dir', skip_data=False)

0 comments on commit 77c5ebf

Please sign in to comment.