
AttributeError with sinkhorn_divergence when passing sinkhorn_kwargs={"rank": some_integer} #485

Closed
Farbodch opened this issue Jan 23, 2024 · 3 comments


@Farbodch

Describe the bug
When

ot = sinkhorn_divergence.sinkhorn_divergence(
    geom,
    x=geom.x,
    y=geom.y,
    static_b=True,
    sinkhorn_kwargs={"rank":10,"initializer":'random'})
return ot.divergence, ot

is called (through jax.jit(jax.value_and_grad(...))), the expected behavior is for the low-rank Sinkhorn (LRSinkhorn) solver to be used. However, an AttributeError is thrown instead:

    179     out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
    180     0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
    181 )
    182 out = (out_xy, out_xx, out_yy)
    183 return SinkhornDivergenceOutput(
--> 184     div, tuple([s.f, s.g] for s in out),
    185     (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
    186     tuple(s.converged for s in out), a, b
    187 )

AttributeError: 'LRSinkhornOutput' object has no attribute 'f'

Full Error Output

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1262], line 109
--> 109 geom_grad = get_geom(my_source, my_target)

Cell In[1262], line 29, in get_geom(source, target)
     27 def get_geom(source, target):
     28     geom = pointcloud.PointCloud(source, target, epsilon=epsilon)
---> 29     (cost, ot), geom_g = cost_fn_vg(geom)
     30     assert ot.converged

    [... skipping hidden 20 frame]

Cell In[1262], line 9, in sink_div(geom)
      4 def sink_div(geom):
      5     """Return the Sinkhorn divergence cost and OT output given a geometry.
      6     Since y is fixed, we can use static_b=True to avoid computing
      7     the OT(b, b) term."""
----> 9     ot = sinkhorn_divergence.sinkhorn_divergence(
     10         geom,
     11         x=geom.x,
     12         y=geom.y,
     13         static_b=True,
     14         sinkhorn_kwargs={"rank":10,"initializer":'random'})
     15     return ot.divergence, ot

File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:103, in sinkhorn_divergence(geom, a, b, sinkhorn_kwargs, static_b, share_epsilon, symmetric_sinkhorn, *args, **kwargs)
    101 a = jnp.ones(num_a) / num_a if a is None else a
    102 b = jnp.ones(num_b) / num_b if b is None else b
--> 103 return _sinkhorn_divergence(
    104     geom_xy,
    105     geom_x,
    106     geom_y,
    107     a=a,
    108     b=b,
    109     symmetric_sinkhorn=symmetric_sinkhorn,
    110     **sinkhorn_kwargs
    111 )

File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:184, in _sinkhorn_divergence(geometry_xy, geometry_xx, geometry_yy, a, b, symmetric_sinkhorn, **kwargs)
    178 div = (
    179     out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
    180     0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
    181 )
    182 out = (out_xy, out_xx, out_yy)
    183 return SinkhornDivergenceOutput(
--> 184     div, tuple([s.f, s.g] for s in out),
    185     (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
    186     tuple(s.converged for s in out), a, b
    187 )

File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:184, in <genexpr>(.0)
    178 div = (
    179     out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
    180     0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
    181 )
    182 out = (out_xy, out_xx, out_yy)
    183 return SinkhornDivergenceOutput(
--> 184     div, tuple([s.f, s.g] for s in out),
    185     (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
    186     tuple(s.converged for s in out), a, b
    187 )

AttributeError: 'LRSinkhornOutput' object has no attribute 'f'

To Reproduce
Relevant code snippet used (to reproduce the behavior):

import jax
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence


def sink_div(geom):
    """Return the Sinkhorn divergence cost and OT output given a geometry.
    Since y is fixed, we can use static_b=True to avoid computing
    the OT(b, b) term."""

    ot = sinkhorn_divergence.sinkhorn_divergence(
        geom,
        x=geom.x,
        y=geom.y,
        static_b=True,
        sinkhorn_kwargs={"rank": 10, "initializer": "random"})
    return ot.divergence, ot

cost_fn_vg = jax.jit(jax.value_and_grad(sink_div, has_aux=True))


def get_geom(source, target):
    geom = pointcloud.PointCloud(source, target, epsilon=epsilon)
    (cost, ot), geom_g = cost_fn_vg(geom)
    assert ot.converged

    return geom_g.x

Additional information (please complete the following information):
The overall script works as expected when sink_div (as stated above) is replaced with the direct low-rank Sinkhorn solver (sink_lr_cost below):

from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def sink_lr_cost(geom):
    """Return the OT cost and OT output given a geometry."""
    ot = sinkhorn_lr.LRSinkhorn(rank=10, initializer="random")(
        linear_problem.LinearProblem(geom))
    return ot.reg_ot_cost, ot

cost_fn_vg = jax.jit(jax.value_and_grad(sink_lr_cost, has_aux=True))
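
If the divergence value itself is needed before a fix lands, one rough workaround is to assemble it from three low-rank solves directly. The sketch below is not part of the original report: lr_sink_div is a hypothetical helper, and it assumes uniform weights, so the 0.5 * epsilon * (sum(a) - sum(b)) ** 2 correction term from the traceback vanishes.

import jax
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def lr_sink_div(x, y, epsilon, rank=10):
    """Sinkhorn divergence assembled from three LRSinkhorn solves,
    mirroring the formula in _sinkhorn_divergence for uniform weights."""
    solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer="random")
    solve = lambda g: solver(linear_problem.LinearProblem(g))
    out_xy = solve(pointcloud.PointCloud(x, y, epsilon=epsilon))
    out_xx = solve(pointcloud.PointCloud(x, x, epsilon=epsilon))
    out_yy = solve(pointcloud.PointCloud(y, y, epsilon=epsilon))
    return out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost)

# Differentiated w.r.t. x, as in the report.
lr_sink_div_vg = jax.jit(jax.value_and_grad(lr_sink_div))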

System/Environment information

  • Python: 3.11.5
  • jax.devices() -> [cuda(id=0)] (GPU)
@marcocuturi
Contributor

Hi @Farbodch,
Sorry about this, we did not expect this usage, but it is indeed very valid, especially since we explored it here: https://proceedings.neurips.cc/paper_files/paper/2022/hash/2d69e771d9f274f7c624198ea74f5b98-Abstract-Conference.html

Essentially this is just an API bug, and it should work all right; it's just that we shouldn't try to pull the .f and .g potentials from the LR Sinkhorn output in that case.

How urgent is this? If you need this for ICML, let us know, and we can come up with a slightly dumb patch.
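
For illustration only, a minimal sketch of the kind of guard described above, i.e. only collecting the [f, g] potentials when the solver output actually exposes them (a hypothetical helper, not the actual patch):

def collect_potentials(outputs):
    """Collect [f, g] dual potentials per solver output, returning None for
    outputs (e.g. low-rank Sinkhorn) that do not expose them."""
    return tuple(
        [out.f, out.g] if hasattr(out, "f") and hasattr(out, "g") else None
        for out in outputs
    )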

@Farbodch
Author

Hi @marcocuturi

Thank you for the reply! It’s not super urgent/it won’t make it to ICML, but I would greatly appreciate any updates!

@michalk8
Collaborator

michalk8 commented Aug 9, 2024

Hi @Farbodch, it's finally implemented!

Closed via #568.

@michalk8 michalk8 closed this as completed Aug 9, 2024