Skip to content

Commit

Permalink
remove unused categories upon slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
falexwolf committed Feb 9, 2018
1 parent d954103 commit 8cabf9c
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 55 deletions.
54 changes: 48 additions & 6 deletions anndata/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
import os, sys
import warnings
import logging
import logging as logg
from enum import Enum
from collections import Mapping, Sequence, Sized
import numpy as np
Expand All @@ -13,13 +13,14 @@
from scipy.sparse import issparse
from scipy.sparse.sputils import IndexMixin
from textwrap import dedent
from natsort import natsorted

from . import h5py
from . import utils

# FORMAT = '%(levelname)s: %(message)s' # TODO: add a better formatter
FORMAT = '%(message)s'
logging.basicConfig(format=FORMAT, level=logging.INFO, stream=sys.stdout)
logg.basicConfig(format=FORMAT, level=logg.INFO, stream=sys.stdout)

_MAIN_NARRATIVE = """\
:class:`~anndata.AnnData` stores a data matrix ``.X`` together with
Expand Down Expand Up @@ -640,23 +641,23 @@ def _init_as_view(self, adata_ref, oidx, vidx):
oidx_normalized, vidx_normalized = oidx, vidx
if isinstance(oidx, (int, np.int64)): oidx_normalized = slice(oidx, oidx+1, 1)
if isinstance(vidx, (int, np.int64)): vidx_normalized = slice(vidx, vidx+1, 1)
self._obs = DataFrameView(adata_ref.obs.iloc[oidx_normalized], view_args=(self, 'obs'))
self._var = DataFrameView(adata_ref.var.iloc[vidx_normalized], view_args=(self, 'var'))
obs_sub = adata_ref.obs.iloc[oidx_normalized]
var_sub = adata_ref.var.iloc[vidx_normalized]
self._obsm = ArrayView(adata_ref.obsm[oidx_normalized], view_args=(self, 'obsm'))
self._varm = ArrayView(adata_ref.varm[vidx_normalized], view_args=(self, 'varm'))
# hackish solution here, no copy should be necessary
uns_new = self._adata_ref._uns.copy()
# fix _n_obs, _n_vars
if isinstance(oidx, slice):
self._n_obs = len(self._obs.index)
self._n_obs = len(obs_sub.index)
elif isinstance(oidx, (int, np.int64)):
self._n_obs = 1
elif isinstance(oidx, Sized):
self._n_obs = len(oidx)
else:
raise KeyError('Unknown Index type')
if isinstance(vidx, slice):
self._n_vars = len(self._var.index)
self._n_vars = len(var_sub.index)
elif isinstance(vidx, (int, np.int64)):
self._n_vars = 1
elif isinstance(vidx, Sized):
Expand All @@ -665,6 +666,12 @@ def _init_as_view(self, adata_ref, oidx, vidx):
raise KeyError('Unknown Index type')
# need to do the slicing after setting self._n_obs, self._n_vars
self._slice_uns_sparse_matrices_inplace(uns_new, self._oidx)
# fix categories
self._remove_unused_categories(adata_ref.obs, obs_sub, uns_new)
self._remove_unused_categories(adata_ref.var, var_sub, uns_new)
# set attributes
self._obs = DataFrameView(obs_sub, view_args=(self, 'obs'))
self._var = DataFrameView(var_sub, view_args=(self, 'var'))
self._uns = DictView(uns_new, view_args=(self, 'uns'))
# set data
if self.isbacked: self._X = None
Expand Down Expand Up @@ -1128,8 +1135,10 @@ def _getitem_copy(self, index):
if not self.isbacked: X = self._X[oidx, vidx]
else: X = self.file['X'][oidx, vidx]
obs_new = self._obs.iloc[oidx]
self._remove_unused_categories(self._obs, obs_new, self._uns)
obsm_new = self._obsm[oidx]
var_new = self._var.iloc[vidx]
self._remove_unused_categories(self._var, var_new, self._uns)
varm_new = self._varm[vidx]
assert obs_new.shape[0] == X.shape[0], (oidx, obs_new)
assert var_new.shape[0] == X.shape[1], (vidx, var_new)
Expand All @@ -1138,6 +1147,39 @@ def _getitem_copy(self, index):
raw_new = None if self.raw is None else self.raw[oidx]
return AnnData(X, obs_new, var_new, uns_new, obsm_new, varm_new, raw=raw_new)

def _remove_unused_categories(self, df_full, df_sub, uns):
from pandas.api.types import is_categorical
for k in df_full:
if is_categorical(df_full[k]):
all_categories = df_full[k].cat.categories
df_sub[k].cat.remove_unused_categories(inplace=True)
# also correct the colors...
if k + '_colors' in uns:
uns[k + '_colors'] = uns[
k + '_colors'][
np.where(np.in1d(
all_categories, df_sub[k].cat.categories))[0]]

def _sanitize(self):
    """Transform string annotation columns to categorical dtype.

    A column of ``.obs`` or ``.var`` is converted only if it stores fewer
    categories than the total number of entries — otherwise categorizing
    would not deduplicate anything. Categories are natsorted and stored
    as unicode strings; a log message tells the user how to access them.
    """
    from pandas.api.types import is_string_dtype
    for ann in ['obs', 'var']:
        # fetch the frame once per attribute instead of per column
        df = getattr(self, ann)
        for key in df.columns:
            # NOTE(review): is_string_dtype is True for any object column,
            # so this presumably assumes pure string columns — confirm
            if not is_string_dtype(df[key]):
                continue
            c = pd.Categorical(
                df[key], categories=natsorted(np.unique(df[key])))
            # only categorize if this actually compresses the column
            if len(c.categories) < len(c):
                # rename the categories up front: assigning to
                # `.cat.categories` is mutation-by-assignment, which
                # modern pandas forbids
                df[key] = c.rename_categories(c.categories.astype('U'))
                logg.info(
                    '... storing {} as categorical type'.format(key))
                logg.info(
                    '    access categories as adata.{}[\'{}\'].cat.categories'
                    .format(ann, key))

def _slice_uns_sparse_matrices_inplace(self, uns, oidx):
# slice sparse spatrices of n_obs × n_obs in self.uns
if not (isinstance(oidx, slice) and
Expand Down
96 changes: 49 additions & 47 deletions anndata/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_creation():
AnnData(sp.eye(2))
AnnData(
np.array([[1, 2, 3], [4, 5, 6]]),
dict(Smp=['A', 'B']),
dict(Obs=['A', 'B']),
dict(Feat=['a', 'b', 'c']))

assert AnnData(np.array([1, 2])).X.shape == (2,)
Expand All @@ -27,10 +27,10 @@ def test_creation():
def test_names():
adata = AnnData(
np.array([[1, 2, 3], [4, 5, 6]]),
dict(smp_names=['A', 'B']),
dict(obs_names=['A', 'B']),
dict(var_names=['a', 'b', 'c']))

assert adata.smp_names.tolist() == 'A B'.split()
assert adata.obs_names.tolist() == 'A B'.split()
assert adata.var_names.tolist() == 'a b c'.split()

adata = AnnData(np.array([[1, 2], [3, 4], [5, 6]]),
Expand All @@ -41,10 +41,10 @@ def test_names():
def test_indices_dtypes():
    # obs/var name indices must accept unicode reassignment
    # (diff scrape had kept the stale `smp_names` lines; this is the
    # post-commit `obs` version)
    adata = AnnData(
        np.array([[1, 2, 3], [4, 5, 6]]),
        dict(obs_names=['A', 'B']),
        dict(var_names=['a', 'b', 'c']))
    adata.obs_names = ['ö', 'a']
    assert adata.obs_names.tolist() == ['ö', 'a']


def test_creation_from_vector():
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_slicing():
def test_slicing_strings():
adata = AnnData(
np.array([[1, 2, 3], [4, 5, 6]]),
dict(smp_names=['A', 'B']),
dict(obs_names=['A', 'B']),
dict(var_names=['a', 'b', 'c']))

assert adata['A', 'a'].X.tolist() == 1
Expand All @@ -93,7 +93,7 @@ def test_slicing_strings():
def test_slicing_series():
adata = AnnData(
np.array([[1, 2], [3, 4], [5, 6]]),
dict(smp_names=['A', 'B', 'C']),
dict(obs_names=['A', 'B', 'C']),
dict(var_names=['a', 'b']))
df = pd.DataFrame({'a': ['1', '2', '2']})
df1 = pd.DataFrame({'b': ['1', '2']})
Expand All @@ -102,119 +102,121 @@ def test_slicing_series():
assert (adata[:, df1['b'].values == '2'].X.tolist()
== adata[:, df1['b'] == '2'].X.tolist())

def test_slicing_remove_unused_categories():
    # after slicing, categorical annotations must drop the categories
    # that no longer occur in the subset
    adata = AnnData(
        np.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
        dict(k=['a', 'a', 'b', 'b']))
    adata._sanitize()  # turn the string column 'k' into a categorical
    # rows 3:5 only contain category 'b'; 'a' must be gone
    # (removed the debug print() calls and the stray quit(), which
    # killed the interpreter and aborted the whole test session)
    assert adata[3:5].obs['k'].cat.categories.tolist() == ['b']

def test_get_subset_annotation():
    # a single-element subset must carry its obs/var annotation along
    # (diff scrape had kept the stale `.smp` assert; this is the
    # post-commit `.obs` version)
    adata = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
                    dict(S=['A', 'B']),
                    dict(F=['a', 'b', 'c']))
    assert adata[0, 0].obs['S'].tolist() == ['A']
    assert adata[0, 0].var['F'].tolist() == ['a']


def test_transpose():
    # transposing must swap obs/var without touching the original
    # (diff scrape had kept the stale `smp` lines; this is the
    # post-commit `obs` version)
    adata = AnnData(
        np.array([[1, 2, 3], [4, 5, 6]]),
        dict(obs_names=['A', 'B']),
        dict(var_names=['a', 'b', 'c']))

    adata1 = adata.T

    # make sure to not modify the original!
    assert adata.obs_names.tolist() == ['A', 'B']
    assert adata.var_names.tolist() == ['a', 'b', 'c']

    assert adata1.obs_names.tolist() == ['a', 'b', 'c']
    assert adata1.var_names.tolist() == ['A', 'B']
    assert adata1.X.shape == adata.X.T.shape

    # .T and .transpose() must agree
    adata2 = adata.transpose()
    assert np.array_equal(adata1.X, adata2.X)
    assert np.array_equal(adata1.obs, adata2.obs)
    assert np.array_equal(adata1.var, adata2.var)


def test_append_col():
    # appending a column of the right length works; wrong length raises
    # (diff scrape had kept the stale `.smp` lines; this is the
    # post-commit `.obs` version)
    adata = AnnData(np.array([[1, 2, 3], [4, 5, 6]]))

    adata.obs['new'] = [1, 2]
    # this worked in the initial AnnData, but not with a dataframe
    # adata.obs[['new2', 'new3']] = [['A', 'B'], ['c', 'd']]

    from pytest import raises
    with raises(ValueError):
        adata.obs['new4'] = 'far too long'.split()


def test_set_obs():
    # replacing .obs with a DataFrame of the right length works;
    # a wrong length or a plain dict raises
    # (diff scrape had kept the stale `test_set_smp` definition and the
    # removed commented-out test_print; this is the post-commit version)
    adata = AnnData(np.array([[1, 2, 3], [4, 5, 6]]))

    adata.obs = pd.DataFrame({'a': [3, 4]})
    assert adata.obs_names.tolist() == [0, 1]

    from pytest import raises
    with raises(ValueError):
        adata.obs = pd.DataFrame({'a': [3, 4, 5]})
        adata.obs = {'a': [1, 2]}


def test_multicol():
    # multi-column annotation goes into .obsm and keeps its columns
    # (diff scrape had kept the stale `smpm` lines; this is the
    # post-commit `obsm` version)
    adata = AnnData(np.array([[1, 2, 3], [4, 5, 6]]))
    # 'c' keeps the columns as should be
    adata.obsm['c'] = np.array([[0., 1.], [2, 3]])
    assert adata.obsm_keys() == ['c']
    assert adata.obsm['c'].tolist() == [[0., 1.], [2, 3]]


def test_n_obs():
    # n_obs reflects the number of rows and shrinks under slicing
    # (diff scrape had kept the stale `n_smps` lines; this is the
    # post-commit `n_obs` version)
    adata = AnnData(np.array([[1, 2], [3, 4], [5, 6]]))
    assert adata.n_obs == 3
    adata1 = adata[:2, ]
    assert adata1.n_obs == 2


def test_concatenate():
    # concatenation intersects var_names, outer-joins obs columns and
    # adds a batch annotation
    # (diff scrape had kept the stale `smp` lines; this is the
    # post-commit `obs` version)
    adata1 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
                     {'obs_names': ['s1', 's2'],
                      'anno1': ['c1', 'c2']},
                     {'var_names': ['a', 'b', 'c']})
    adata2 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
                     {'obs_names': ['s3', 's4'],
                      'anno1': ['c3', 'c4']},
                     {'var_names': ['b', 'c', 'd']})
    adata3 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
                     {'obs_names': ['s5', 's6'],
                      'anno2': ['d3', 'd4']},
                     {'var_names': ['b', 'c', 'd']})
    adata = adata1.concatenate([adata2, adata3])
    assert adata.n_vars == 2
    assert adata.obs_keys() == ['anno1', 'anno2', 'batch']
    # a custom batch key replaces the default 'batch' column name
    adata = adata1.concatenate([adata2, adata3], batch_key='batch1')
    assert adata.obs_keys() == ['anno1', 'anno2', 'batch1']
    # custom batch categories replace the default integer labels
    adata = adata1.concatenate([adata2, adata3], batch_categories=['a1', 'a2', 'a3'])
    assert adata.obs['batch'].cat.categories.tolist() == ['a1', 'a2', 'a3']



def test_concatenate_sparse():
from scipy.sparse import csr_matrix
adata1 = AnnData(csr_matrix([[0, 2, 3], [0, 5, 6]]),
{'smp_names': ['s1', 's2'],
{'obs_names': ['s1', 's2'],
'anno1': ['c1', 'c2']},
{'var_names': ['a', 'b', 'c']})
adata2 = AnnData(csr_matrix([[0, 2, 3], [0, 5, 6]]),
{'smp_names': ['s3', 's4'],
{'obs_names': ['s3', 's4'],
'anno1': ['c3', 'c4']},
{'var_names': ['b', 'c', 'd']})
adata3 = AnnData(csr_matrix([[1, 2, 0], [0, 5, 6]]),
{'smp_names': ['s5', 's6'],
{'obs_names': ['s5', 's6'],
'anno2': ['d3', 'd4']},
{'var_names': ['b', 'c', 'd']})
adata = adata1.concatenate([adata2, adata3])
Expand Down
4 changes: 2 additions & 2 deletions anndata/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import logging
import logging as logg
import pandas as pd


Expand Down Expand Up @@ -40,7 +40,7 @@ def make_index_unique(index, join=''):

def warn_names_duplicates(string, df):
names = 'Observation' if string == 'obs' else 'Variable'
logging.info(
logg.info(
'{} names are not unique. '
'To make them unique, call `.{}_names_make_unique()`.\n'
'Duplicates are: {}'.format(names, string, df.index.get_duplicates()))

0 comments on commit 8cabf9c

Please sign in to comment.