Online method throws error #55

Closed
MUCDK opened this issue Dec 17, 2021 · 7 comments
Labels
solvers interface moscot's interface with OTT solvers

Comments

@MUCDK
Collaborator

MUCDK commented Dec 17, 2021

After installing the requirements as instructed, I get the following problem:

`Regularized.fit()` does not work if `online` is set to `True` in the geometry object.

```python
import anndata
import jax.numpy as jnp
import scanpy as sc
from ott.geometry.pointcloud import PointCloud
from moscot._solver import Regularized

# Split the dataset into source (depth 11) and target (depth 12) cells.
adata = anndata.read("/home/icb/dominik.klein/git_repos/data/adatas/adata_tedsim_8192.h5ad")
obs_var_time = "depth"
adata_source = adata[adata.obs[obs_var_time] == 11]
adata_target = adata[adata.obs[obs_var_time] == 12]

sc.pp.pca(adata_source)
sc.pp.pca(adata_target)

# Same point clouds; only the `online` flag differs.
pointcloud_offline = PointCloud(x=jnp.asarray(adata_source.X), y=jnp.asarray(adata_target.X), online=False)
pointcloud_online = PointCloud(x=jnp.asarray(adata_source.X), y=jnp.asarray(adata_target.X), online=True)

moscot_solver = Regularized(epsilon=0.2)

moscot_solver.fit(pointcloud_offline)  # works
moscot_solver.fit(pointcloud_online)  # raises the TypeError below
```
Error message:

```
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_27778/1474808108.py in <module>
      1 moscot_solver = Regularized(epsilon=0.2)
      2 
----> 3 moscot_solver.fit(pointcloud_online)

/mnt/home/icb/dominik.klein/git_repos/moscot/moscot/_solver.py in fit(self, geom, a, b, **kwargs)
     96         """
     97         geom = self._prepare_geom(geom, **kwargs)
---> 98         self._transport = Transport(geom, a=a, b=b, **self._kwargs)
     99         self._check_marginals(a, b)
    100 

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/tools/transport.py in __init__(self, a, b, *args, **kwargs)
     66       self.geom = pointcloud.PointCloud(*args, **pc_kw)
     67 
---> 68     num_a, num_b = self.geom.shape
     69     self.a = jnp.ones((num_a,)) / num_a if a is None else a
     70     self.b = jnp.ones((num_b,)) / num_b if b is None else b

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
    138   @property
    139   def shape(self):
--> 140     mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
    141     if mat is not None:
    142       return mat.shape

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    110       # If no epsilon was passed on to the geometry, then assume it is one by
    111       # default.
--> 112       cost = -jnp.log(self._kernel_matrix)
    113       return cost if self._epsilon_init is None else self.epsilon * cost
    114     return self._cost_matrix

    [... skipping hidden 15 frame]

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x)
    690 def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
    691   if promote_to_inexact:
--> 692     fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
    693   else:
    694     fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _promote_args_inexact(fun_name, *args)
    600 
    601   Promotes non-inexact types to an inexact type."""
--> 602   _check_arraylike(fun_name, *args)
    603   _check_no_float0s(fun_name, *args)
    604   return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    576                     if not _arraylike(arg))
    577     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 578     raise TypeError(msg.format(fun_name, type(arg), pos))
    579 
    580 def _check_no_float0s(fun_name, *args):

TypeError: log requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.
```
@Marius1311
Collaborator

On hold.

@Marius1311
Collaborator

@michalk8 agreed to look into this a bit.

@michalk8
Collaborator

michalk8 commented Jan 7, 2022

The problem is here: https://github.com/theislab/moscot/blob/dev/moscot/_solver.py#L45
`geom.cost_matrix` is `None` when `online=True`. The `Transport` object then tries to access `self.geom.shape`, which throws the above error. I'm not sure whether constructing a geometry with neither a cost nor a kernel matrix (or passing both and later ignoring the kernel matrix) is by design. I can't find a use case for it, except for copying `eps` from another temporary geometry to the current one; that can be useful, and it's what I should've done in the above example. IMHO, this is not really a bug but a feature. Minimal reproducible example:

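```python
from ott.geometry.geometry import Geometry

# Neither cost_matrix nor kernel_matrix is given, so `shape` fails as in the traceback above.
Geometry().shape
```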
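The property chain in the traceback can be mirrored without OTT installed. The class below is a hypothetical, simplified stand-in written for illustration only (it is not `ott.geometry.geometry.Geometry`, and it uses NumPy instead of `jax.numpy`): `shape` first evaluates `cost_matrix`, which falls back to `-log(kernel_matrix)`, so a geometry storing neither matrix fails with a `TypeError` before `shape` can fall back to anything else.

```python
import numpy as np


class ToyGeometry:
    """Hypothetical, simplified mirror of the lookup order in the traceback.

    Not OTT's class: it only reproduces shape -> cost_matrix -> -log(kernel).
    """

    def __init__(self, cost_matrix=None, kernel_matrix=None, epsilon=1.0):
        self._cost_matrix = cost_matrix
        self._kernel_matrix = kernel_matrix
        self.epsilon = epsilon

    @property
    def cost_matrix(self):
        if self._cost_matrix is None:
            # Fall back to the kernel; if it is None too, np.log raises TypeError.
            return -self.epsilon * np.log(self._kernel_matrix)
        return self._cost_matrix

    @property
    def kernel_matrix(self):
        if self._kernel_matrix is None:
            return np.exp(-self._cost_matrix / self.epsilon)
        return self._kernel_matrix

    @property
    def shape(self):
        # Evaluating `self.cost_matrix` here is what triggers the failure.
        mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
        return mat.shape


dense = ToyGeometry(cost_matrix=np.ones((3, 4)))
print(dense.shape)  # (3, 4)

empty = ToyGeometry()  # neither matrix stored, like Geometry() above
try:
    empty.shape
except TypeError as err:
    print("TypeError:", err)
```

In this toy version a geometry holding either matrix reports its shape; only the empty one fails.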

@Marius1311
Collaborator

Ok, should we

  • close this or
  • open an issue in OTT?

@Marius1311
Collaborator

@MUCDK, can we move this issue to OTT and close it here?

@Marius1311
Collaborator

Or just close it if it's no longer relevant, please.

Marius1311 added the solvers interface (moscot's interface with OTT solvers) label on Jan 24, 2022
@MUCDK
Collaborator Author

MUCDK commented Jan 24, 2022

I think we can close it for now; we could think about catching the error ourselves.

MUCDK closed this as completed on Jan 24, 2022

3 participants