-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deprecate power
in PointCloud
, introduce TICost
and use it to compute Entropic (Brenier) maps.
#167
Merged
Merged
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
3991632
deperecate `power`, introduce h maps in potentials
marcocuturi d652448
Deprecate power and introduce h function in costs.
marcocuturi 3c80c91
linter
marcocuturi 9086039
linter
marcocuturi 7fcb609
revert abstractmethod.
marcocuturi eca30bb
linter
marcocuturi 425e05b
linter
marcocuturi 220bcd8
PNorm -> SqPNorm
marcocuturi a984da0
PNorm -> SqPNorm in tests.
marcocuturi 9bf487a
another fix for abstract method.
marcocuturi 26cef9c
fix abc.abstractmethod
marcocuturi f706007
linter
marcocuturi 94f5e19
nb fix
marcocuturi a2a1830
linter
marcocuturi c3c1a6a
nb bug fix
marcocuturi b37c518
modify ipynb
marcocuturi 6ad0a05
abc.abstractmethod for RBF
marcocuturi cdbcab2
fixes and additions.
marcocuturi eb3f922
fix `cor` in neuraldual
marcocuturi d4b5161
fix in neuraldual
marcocuturi 906d042
p-norm ** p implemented, fixes.
marcocuturi b779465
various fixes. Change to `TICost`
marcocuturi f099201
various fixes
marcocuturi 9ba0879
fix nb
marcocuturi 09bffc0
last fixes.
marcocuturi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,7 +47,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: | |
pass | ||
|
||
def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: | ||
pass | ||
raise NotImplementedError("Barycenter not yet implemented for this cost.") | ||
|
||
@classmethod | ||
def padder(cls, dim: int) -> jnp.ndarray: | ||
|
@@ -90,17 +90,88 @@ def tree_unflatten(cls, aux_data, children): | |
return cls(*children) | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class | ||
class RBFCost(CostFn): | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""A radial-basis function cost class for translation invariant costs. | ||
|
||
Such costs are defined as | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
c(x,y) = h(z), where z := x-y. | ||
|
||
where h is a function strictly convex (or concave) function mapping vectors | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a small repetition here: I think you meant "where h is a strictly convex (or concave) function ...". It's minor, I know ;) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great catch! thanks. |
||
to real-values. | ||
|
||
For completeness (and differentiation using the Brenier theorem), the user | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
is also supposed to provide the Legendre transform of `h`, whose gradient (the | ||
inverse of the gradient of `h`) will be used to form a Brenier map. | ||
""" | ||
|
||
def h(self, z: jnp.ndarray) -> float: | ||
michalk8 marked this conversation as resolved.
Show resolved
Hide resolved
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pass | ||
|
||
def h_legendre(self, z: jnp.ndarray) -> float: | ||
michalk8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pass | ||
|
||
def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: | ||
"""Evaluate h on difference between x and y.""" | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self.h(x - y) | ||
|
||
def tree_flatten(self): | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return (), None | ||
|
||
def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pass | ||
|
||
@classmethod | ||
def tree_unflatten(cls, aux_data, children): | ||
del aux_data | ||
return cls(*children) | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class | ||
class SqPNorm(RBFCost): | ||
"""Squared p-norm of the difference of two vectors. | ||
|
||
For details on the derivation of the Legendre transform of the norm, see e.g. | ||
the reference :cite:`boyd:04`, p.93/94. | ||
https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
p: float | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__(self, p: float): | ||
self.p = p | ||
marcocuturi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.q = 1. / (1 - 1 / self.p) | ||
|
||
def h(self, z: jnp.ndarray) -> float: | ||
return 0.5 * jnp.linalg.norm(z, self.p) ** 2 | ||
|
||
def h_legendre(self, z: jnp.ndarray) -> float: | ||
return 0.5 * jnp.linalg.norm(z, self.q) ** 2 | ||
|
||
def tree_flatten(self): | ||
return (), (self.p,) | ||
|
||
@classmethod | ||
def tree_unflatten(cls, aux_data, children): | ||
del children | ||
return cls(aux_data[0]) | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class | ||
class Euclidean(CostFn): | ||
"""Euclidean distance.""" | ||
"""Euclidean distance. | ||
|
||
Note that the Euclidean distance is not cast as a RBF cost, because this | ||
would correspond to `h = abs`, whose gradient is not invertible. | ||
""" | ||
|
||
def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: | ||
"""Compute Euclidean norm.""" | ||
return jnp.linalg.norm(x - y) | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class | ||
class SqEuclidean(CostFn): | ||
class SqEuclidean(RBFCost): | ||
"""Squared Euclidean distance.""" | ||
|
||
def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: | ||
|
@@ -111,6 +182,12 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: | |
"""Compute minus twice the dot-product between vectors.""" | ||
return -2. * jnp.vdot(x, y) | ||
|
||
def h(self, z: jnp.ndarray) -> float: | ||
return jnp.sum(z ** 2) | ||
|
||
def h_legendre(self, z: jnp.ndarray) -> float: | ||
return 0.25 * jnp.sum(z ** 2) | ||
|
||
def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: | ||
"""Output barycenter of vectors when using squared-Euclidean distance.""" | ||
return jnp.average(xs, weights=weights, axis=0) | ||
|
@@ -134,9 +211,6 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: | |
# similarity is in [-1, 1], clip because of numerical imprecisions | ||
return jnp.clip(cosine_distance, 0., 2.) | ||
|
||
def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float: | ||
raise NotImplementedError("Barycenter for cosine cost not yet implemented.") | ||
|
||
@classmethod | ||
def padder(cls, dim: int) -> jnp.ndarray: | ||
return jnp.ones((1, dim)) | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the same vein as
cost_fn
, I'm wondering ifPotential_t
could be renamed toPotentialFn_t
or justPotentialFn
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had the same reflex when Michal used it for the first time, but I think it makes sense :) Here it turns out this is just a type (
_t
) and can be either a vector of a function.