compute_sparse_laplacian gives int32 vs int64 index mismatch when input is from scipy csr #509

Closed
Tracked by #677
selmanozleyen opened this issue Mar 30, 2024 · 0 comments · Fixed by #510

Comments

@selmanozleyen
Contributor

Hi,
compute_sparse_laplacian raises an int32 vs. int64 index dtype mismatch when its input comes from a SciPy CSR matrix. I am on macOS with an M1. I suspect this bug might be device-specific, since some of the ott-jax tests cover similar scenarios.
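For now I can sidestep the error on my side by rebuilding the BCOO with its index dtype cast to whatever jnp.arange produces before calling from_graph (just a workaround sketch based on my guess about the cause, not a fix for the library):

import jax.numpy as jnp
import jax.experimental.sparse as jesp
import scipy.sparse as sp
from ott.geometry import geodesic

mat = sp.rand(10, 10, 0.1)
G = jesp.BCOO.from_scipy_sparse(mat)
# cast the index dtype to match jnp.arange, which is what
# compute_sparse_laplacian uses to build the degree matrix
G = jesp.BCOO((G.data, G.indices.astype(jnp.arange(1).dtype)), shape=G.shape)
geodesic.Geodesic.from_graph(G, t=1.0, directed=True)  # no dtype error with the cast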

To Reproduce

import jax
import jax.numpy as jnp
import jax.experimental.sparse as jesp
import scipy.sparse as sp
from ott.geometry import geodesic
import numpy as np

When I run this

arr2 = sp.rand(10, 10, 0.1)
arr2 = jesp.BCOO.from_scipy_sparse(arr2)

geodesic.Geodesic.from_graph(arr2, t=1.0, directed=True)

I get this error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 4
      1 arr2 = sp.rand(10, 10, 0.1)
      2 arr2 = jesp.BCOO.from_scipy_sparse(arr2)
----> 4 geodesic.Geodesic.from_graph(arr2, t=1.0, directed=True)

File ~/Documents/projects/ott/src/ott/geometry/geodesic.py:111, in Geodesic.from_graph(cls, G, t, eigval, order, directed, normalize, rng, **kwargs)
    108   t = (jnp.sum(G) / jnp.sum(G > 0.0)) ** 2
    110 if isinstance(G, jesp.BCOO):
--> 111   laplacian = compute_sparse_laplacian(G, normalize)
    112 else:
    113   laplacian = compute_dense_laplacian(G, normalize)

File ~/Documents/projects/ott/src/ott/geometry/geodesic.py:242, in compute_sparse_laplacian(G, normalize)
    240 data_degree, ixs = G.sum(1).todense(), jnp.arange(n)
    241 degree = jesp.BCOO((data_degree, jnp.c_[ixs, ixs]), shape=(n, n))
--> 242 laplacian = degree - G
    243 if normalize:
    244   laplacian = normalize_laplacian(laplacian, data_degree)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/experimental/sparse/transform.py:466, in _sparsify_with_interpreter.<locals>.wrapped(*args, **params)
    464 spenv = SparsifyEnv()
    465 spvalues = arrays_to_spvalues(spenv, args)
--> 466 spvalues_out, out_tree = f_raw(spenv, *spvalues, **params)
    467 out = spvalues_to_arrays(spenv, spvalues_out)
    468 return tree_unflatten(out_tree, out)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/experimental/sparse/transform.py:451, in sparsify_raw.<locals>.wrapped(spenv, *spvalues, **params)
    449 wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
    450 jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
--> 451 result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
    452 if len(out_avals_flat) != len(result):
    453   raise Exception("Internal: eval_sparse does not return expected number of arguments. "
    454                   "Got {result} for avals {out_avals_flat}")

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/experimental/sparse/transform.py:428, in eval_sparse(jaxpr, consts, spvalues, spenv)
    426   if prim not in sparse_rules_bcoo:
    427     _raise_unimplemented_primitive(prim)
--> 428   out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
    429 else:
    430   out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/experimental/sparse/transform.py:651, in _sub_sparse(spenv, *spvalues)
    649 X, Y = spvalues
    650 if X.is_sparse() and Y.is_sparse():
--> 651   return _add_sparse(spenv, X, *sparse_rules_bcoo[lax.neg_p](spenv, Y))
    652 else:
    653   raise NotImplementedError("Subtraction between sparse and dense array.")

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/experimental/sparse/transform.py:627, in _add_sparse(spenv, *spvalues)
    625   raise NotImplementedError("Addition between sparse matrices with different batch/dense dimensions.")
    626 else:
--> 627   out_indices = lax.concatenate([spenv.indices(X), spenv.indices(Y)], dimension=spenv.indices(X).ndim - 2)
    628   out_data = lax.concatenate([spenv.data(X), spenv.data(Y)], dimension=spenv.indices(X).ndim - 2)
    629   out_spvalue = spenv.sparse(X.shape, out_data, out_indices)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/_src/lax/lax.py:622, in concatenate(operands, dimension)
    620   if isinstance(op, Array):
    621     return type_cast(Array, op)
--> 622 return concatenate_p.bind(*operands, dimension=dimension)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/_src/core.py:444, in Primitive.bind(self, *args, **params)
    441 def bind(self, *args, **params):
    442   assert (not config.enable_checks.value or
    443           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 444   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/_src/core.py:447, in Primitive.bind_with_trace(self, trace, args, params)
    446 def bind_with_trace(self, trace, args, params):
--> 447   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    448   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/_src/core.py:935, in EvalTrace.process_primitive(self, primitive, tracers, params)
    934 def process_primitive(self, primitive, tracers, params):
--> 935   return primitive.impl(*tracers, **params)

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/_src/dispatch.py:87, in apply_primitive(prim, *args, **params)
     85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86 try:
---> 87   outs = fun(*args)
     88 finally:
     89   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 20 frame]

File ~/mambaforge/envs/moscot/lib/python3.11/site-packages/jax/_src/lax/lax.py:4838, in check_same_dtypes(name, *avals)
   4836   equiv = _JNP_FUNCTION_EQUIVALENTS[name]
   4837   msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
-> 4838 raise TypeError(msg.format(name, ", ".join(str(a.dtype) for a in avals)))

TypeError: lax.concatenate requires arguments to have the same dtypes, got int64, int32. (Tip: jnp.concatenate is a similar function that does automatic type promotion on inputs).
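
For context, here is a minimal sketch of what I think is going on (my assumption from reading the traceback, with jax_enable_x64 turned on as in my environment): BCOO.from_scipy_sparse keeps the int32 indices coming from SciPy, while compute_sparse_laplacian builds the degree matrix from jnp.arange, which yields int64 indices under x64, so the sparse subtraction degree - G ends up concatenating int32 and int64 index arrays.

import jax
jax.config.update("jax_enable_x64", True)  # assumption: x64 is enabled, as in my environment

import jax.numpy as jnp
import jax.experimental.sparse as jesp
import scipy.sparse as sp

mat = sp.rand(10, 10, 0.1)
G = jesp.BCOO.from_scipy_sparse(mat)
print(G.indices.dtype)  # int32 here, inherited from scipy

n = G.shape[0]
ixs = jnp.arange(n)
print(ixs.dtype)  # int64 under x64

# degree matrix built the same way compute_sparse_laplacian builds it
degree = jesp.BCOO((G.sum(1).todense(), jnp.c_[ixs, ixs]), shape=(n, n))
laplacian = degree - G  # raises the same lax.concatenate dtype error

If that is indeed the cause, casting the indices of either operand to a common dtype (as in the workaround above) makes the subtraction go through.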