Skip to content

Commit

Permalink
fixed bug in concatenate, present since releases 0.5.2, 0.5.3, 0.5.4
Browse files Browse the repository at this point in the history
  • Loading branch information
falexwolf committed Mar 5, 2018
1 parent 6715787 commit 4238647
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
32 changes: 22 additions & 10 deletions anndata/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,9 +1301,12 @@ def copy(self, filename=None):
return AnnData(filename=filename)

def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories=None, index_unique=None):
"""Concatenate along the observations axis after intersecting the variables names.
"""Concatenate along the observations axis.
The `.var`, `.varm`, and `.uns` attributes of the passed adatas are ignored.
The `.uns` and `.varm` attributes of the passed `adatas` are ignored.
If you use `join='outer'`, note that missing data is filled with 0s.
Use this with care.
Parameters
----------
Expand Down Expand Up @@ -1337,7 +1340,7 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
>>> {'anno2': ['d3', 'd4']},
>>> {'var_names': ['b', 'c', 'd']})
>>>
>>> adata = adata1.concatenate(adata2, adata3)
>>> adata = adata1.concatenate(adata2, adata3, index_unique='-')
>>> adata.X
[[ 2. 3.]
[ 5. 6.]
Expand Down Expand Up @@ -1372,9 +1375,17 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
'Making variable names unique for controlled concatenation.')
printed_info = True

# define variable names of joint AnnData
mergers = dict(inner=set.intersection, outer=set.union)
var_names = pd.Index(reduce(mergers[join], (set(ad.var_names) for ad in all_adatas)))

var_names_reduce = reduce(mergers[join], (set(ad.var_names) for ad in all_adatas))
# restore order of initial var_names, append non-sortable names at the end
var_names = []
for v in all_adatas[0].var_names:
if v in var_names_reduce:
var_names.append(v)
var_names_reduce.remove(v) # update the set
var_names = pd.Index(var_names + list(var_names_reduce))

if batch_categories is None:
categories = [str(i) for i, _ in enumerate(all_adatas)]
elif len(batch_categories) == len(all_adatas):
Expand All @@ -1392,11 +1403,11 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
obs_i = 0 # start of next adata’s observations in X
out_obss = []
for i, ad in enumerate(all_adatas):
vars_ad_in_res = var_names.isin(ad.var_names)
vars_res_in_ad = ad.var_names.isin(var_names)
vars_intersect = [v for v in var_names if v in ad.var_names]

# X
X[obs_i:obs_i+ad.n_obs, vars_ad_in_res] = ad.X[:, vars_res_in_ad]
X[obs_i:obs_i+ad.n_obs,
var_names.isin(vars_intersect)] = ad[:, vars_intersect].X
obs_i += ad.n_obs

# obs
Expand All @@ -1412,13 +1423,14 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
out_obss.append(obs)

# var
var.loc[vars_ad_in_res, ad.var.columns] = ad.var.loc[vars_res_in_ad, :]
# potentially add additional columns
var.loc[vars_intersect, ad.var.columns] = ad.var.loc[vars_intersect, :]

obs = pd.concat(out_obss)
uns = all_adatas[0].uns
obsm = np.concatenate([ad.obsm for ad in all_adatas])
varm = self.varm # TODO

new_adata = AnnData(X, obs, var, uns, obsm, None, filename=self.filename)
if not obs.index.is_unique:
logg.info(
Expand Down
7 changes: 4 additions & 3 deletions anndata/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,20 @@ def test_concatenate():
adata2 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
{'obs_names': ['s3', 's4'],
'anno1': ['c3', 'c4']},
{'var_names': ['b', 'c', 'd']})
{'var_names': ['d', 'c', 'b']})
adata3 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
{'obs_names': ['s5', 's6'],
'anno2': ['d3', 'd4']},
{'var_names': ['b', 'c', 'd']})
{'var_names': ['d', 'c', 'b']})
adata = adata1.concatenate(adata2, adata3)
assert adata.n_vars == 2
assert adata.obs_keys() == ['anno1', 'anno2', 'batch']
adata = adata1.concatenate(adata2, adata3, batch_key='batch1')
assert adata.obs_keys() == ['anno1', 'anno2', 'batch1']
adata = adata1.concatenate(adata2, adata3, batch_categories=['a1', 'a2', 'a3'])
assert adata.obs['batch'].cat.categories.tolist() == ['a1', 'a2', 'a3']

assert adata.var_names.tolist() == ['b', 'c']


def test_concatenate_sparse():
from scipy.sparse import csr_matrix
Expand Down
4 changes: 4 additions & 0 deletions docs/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
See all releases `here <https://github.com/theislab/anndata/releases>`_. The following lists selected improvements.

Warning: versions 0.5.2, 0.5.3 and 0.5.4 contain a bug in
:func:`~anndata.AnnData.concatenate`: variable names were not assigned
correctly. Please upgrade to version 0.5.5 or later.


**February 9, 2018**: version 0.5

Expand Down

0 comments on commit 4238647

Please sign in to comment.