Describe the bug

When the snippet below

```python
ot = sinkhorn_divergence.sinkhorn_divergence(
    geom,
    x=geom.x,
    y=geom.y,
    static_b=True,
    sinkhorn_kwargs={"rank": 10, "initializer": "random"})
return ot.divergence, ot
```

is called (through `jax.jit(jax.value_and_grad(...))`), the low-rank Sinkhorn solver (`LRSinkhorn`) is expected to be used. Instead, an `AttributeError` is raised:
```
179 out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
180 0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
181 )
182 out = (out_xy, out_xx, out_yy)
183 return SinkhornDivergenceOutput(
--> 184 div, tuple([s.f, s.g] for s in out),
185 (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
186 tuple(s.converged for s in out), a, b
187 )
AttributeError: 'LRSinkhornOutput' object has no attribute 'f'
```
Full Error Output
```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[1262], line 109
--> 109 geom_grad = get_geom(my_source, my_target)
Cell In[1262], line 29, in get_geom(source, target)
27 def get_geom(source, target):
28 geom = pointcloud.PointCloud(source, target, epsilon=epsilon)
---> 29 (cost, ot), geom_g = cost_fn_vg(geom)
30 assert ot.converged
[... skipping hidden 20 frame]
Cell In[1262], line 9, in sink_div(geom)
4 def sink_div(geom):
5 """Return the Sinkhorn divergence cost and OT output given a geometry.
6 Since y is fixed, we can use static_b=True to avoid computing
7 the OT(b, b) term."""
----> 9 ot = sinkhorn_divergence.sinkhorn_divergence(
10 geom,
11 x=geom.x,
12 y=geom.y,
13 static_b=True,
14 sinkhorn_kwargs={"rank":10,"initializer":'random'})
15 return ot.divergence, ot
File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:103, in sinkhorn_divergence(geom, a, b, sinkhorn_kwargs, static_b, share_epsilon, symmetric_sinkhorn, *args, **kwargs)
101 a = jnp.ones(num_a) / num_a if a is None else a
102 b = jnp.ones(num_b) / num_b if b is None else b
--> 103 return _sinkhorn_divergence(
104 geom_xy,
105 geom_x,
106 geom_y,
107 a=a,
108 b=b,
109 symmetric_sinkhorn=symmetric_sinkhorn,
110 **sinkhorn_kwargs
111 )
File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:184, in _sinkhorn_divergence(geometry_xy, geometry_xx, geometry_yy, a, b, symmetric_sinkhorn, **kwargs)
178 div = (
179 out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
180 0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
181 )
182 out = (out_xy, out_xx, out_yy)
183 return SinkhornDivergenceOutput(
--> 184 div, tuple([s.f, s.g] for s in out),
185 (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
186 tuple(s.converged for s in out), a, b
187 )
File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:184, in <genexpr>(.0)
178 div = (
179 out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
180 0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
181 )
182 out = (out_xy, out_xx, out_yy)
183 return SinkhornDivergenceOutput(
--> 184 div, tuple([s.f, s.g] for s in out),
185 (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
186 tuple(s.converged for s in out), a, b
187 )
AttributeError: 'LRSinkhornOutput' object has no attribute 'f'
```
To Reproduce
Relevant code snippet used (to reproduce the behavior):
```python
import jax
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

# `epsilon` is a global defined elsewhere in the original script.

def sink_div(geom):
    """Return the Sinkhorn divergence cost and OT output given a geometry.

    Since y is fixed, we can use static_b=True to avoid computing
    the OT(b, b) term."""
    ot = sinkhorn_divergence.sinkhorn_divergence(
        geom,
        x=geom.x,
        y=geom.y,
        static_b=True,
        # A positive `rank` is expected to select the low-rank solver (LRSinkhorn).
        sinkhorn_kwargs={"rank": 10, "initializer": "random"})
    return ot.divergence, ot

cost_fn_vg = jax.jit(jax.value_and_grad(sink_div, has_aux=True))

def get_geom(source, target):
    geom = pointcloud.PointCloud(source, target, epsilon=epsilon)
    (cost, ot), geom_g = cost_fn_vg(geom)
    assert ot.converged
    return geom_g.x
```
Additional information
The overall script works as expected when `sink_div` (as stated above) is replaced with a direct call to the low-rank Sinkhorn solver (`sink_lr_cost` below):
```python
import jax
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

def sink_lr_cost(geom):
    """Return the OT cost and OT output given a geometry."""
    ot = sinkhorn_lr.LRSinkhorn(rank=10, initializer="random")(linear_problem.LinearProblem(geom))
    return ot.reg_ot_cost, ot

cost_fn_vg = jax.jit(jax.value_and_grad(sink_lr_cost, has_aux=True))
```
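Since the divergence itself only needs the three `reg_ot_cost` values (lines 178–181 of the traceback), one possible stop-gap is to run the low-rank solver on the (x, y), (x, x) and (y, y) geometries and combine the costs by hand. The sketch below reproduces only that combination step as a plain function; `debiased_divergence` is a hypothetical helper, not part of OTT, and whether this debiasing is well-behaved for low-rank costs is an assumption not verified here.

```python
def debiased_divergence(cost_xy, cost_xx, cost_yy, epsilon, a, b):
    """Combine three regularized OT costs as in `_sinkhorn_divergence`.

    Hypothetical helper: cost_xy, cost_xx, cost_yy would be the `reg_ot_cost`
    values from solving OT(a, b), OT(a, a) and OT(b, b) respectively.
    `a` and `b` are iterables of marginal weights.
    """
    # Debiasing term: subtract the mean of the two self-transport costs.
    div = cost_xy - 0.5 * (cost_xx + cost_yy)
    # Mass-mismatch penalty; it vanishes when a and b have equal total mass.
    div += 0.5 * epsilon * (sum(a) - sum(b)) ** 2
    return div
```

With JAX arrays one would use `jnp.sum` instead of `sum`; each cost could come from `sink_lr_cost` applied to `PointCloud(x, y)`, `PointCloud(x, x)` and `PointCloud(y, y)` built with the same `epsilon`.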
System/Environment information
Python: 3.11.5
`jax.devices()` -> `[cuda(id=0)]` (GPU)
Farbodch changed the title from `AttributeError - with: sinkhorn_divergence - when: passing in sinkhorn_kwargs={''rank"=#someInteger}` to `AttributeError -> with: sinkhorn_divergence - when: passing in sinkhorn_kwargs={''rank"=#someInteger}` on Jan 23, 2024.
Essentially this is just an API bug and it should work all right; we just shouldn't try to pull the `.f` and `.g` potentials from the LR Sinkhorn output in that case.
How urgent is this? If you need this for ICML, let us know; we can come up with a slightly dumb patch.
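A minimal sketch of the guard described above, replacing `tuple([s.f, s.g] for s in out)` with a version that records `None` for outputs that have no `f` attribute. `DenseOutput`, `LowRankOutput` and `collect_potentials` are hypothetical stand-ins, not OTT's real output classes, and this is not necessarily how the library fixed it.

```python
from typing import Any, NamedTuple


class DenseOutput(NamedTuple):
    """Hypothetical stand-in for a Sinkhorn output exposing dual potentials."""
    f: Any
    g: Any


class LowRankOutput(NamedTuple):
    """Hypothetical stand-in for a low-rank output with no `f` potential."""
    q: Any
    r: Any


def collect_potentials(outputs):
    """Return [f, g] for each output, or None where `f` is not defined.

    Guards on `f`, the attribute named in the AttributeError, so low-rank
    outputs are skipped instead of raising.
    """
    return tuple([s.f, s.g] if hasattr(s, "f") else None for s in outputs)
```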