
Feature/ Introduce initialization methods for Sinkhorn #98

Merged
merged 49 commits into from
Aug 17, 2022

Conversation

JTT94
Collaborator

@JTT94 JTT94 commented Jul 4, 2022

  • Examples here: https://colab.research.google.com/drive/1vncmDEr3t6_OKfVC0PJin8PIRBViPyO0?usp=sharing

  • Initialization methods stored in /core/initializers.py

    • Each initializer inherits from the base class SinkhornInitializer
      • The base class keeps the design flexible for future neural-network initializers
      • Currently added:
        • sorting initializer (n=m) for any cost
        • Gaussian initializer for the squared Euclidean ground cost; only works with point clouds, as it needs access to x
    • Each initializer has methods init_dual_a and init_dual_b
        def init_dual_a(
            self, ot_problem: LinearProblem, lse_mode: bool = True
        ) -> jnp.ndarray:
          """Initialization for Sinkhorn potential f."""
    
    • The base class also holds the default behaviour: potentials are initialized to 0 in lse_mode and to 1 in kernel mode, with special handling of entries with zero weight
  • Modification to the Sinkhorn API:

    • pass an instantiated initializer to Sinkhorn
    Sinkhorn(
       lse_mode=lse_mode,
       threshold=threshold,
       norm_error=norm_error,
       inner_iterations=inner_iterations,
       min_iterations=min_iterations,
       max_iterations=max_iterations,
       momentum=momentum_lib.Momentum(start=chg_momentum_from, value=momentum),
       anderson=anderson,
       implicit_diff=implicit_diff,
       parallel_dual_updates=parallel_dual_updates,
       use_danskin=use_danskin,
       potential_initializer=potential_initializer,
       jit=jit
    )
    
    • The initializer can also be passed to the sinkhorn functional wrapper
    gaus_init = init_lib.GaussianInitializer()

    @jax.jit
    def run_sinkhorn_gaus_init(x, y, a=None, b=None):
        sink_kwargs = {
            'jit': True,
            'threshold': 0.001,
            'max_iterations': 10**5,
            'potential_initializer': gaus_init,
        }
        geom_kwargs = {'epsilon': 0.01}
        geom = PointCloud(x, y, **geom_kwargs)
        out = sinkhorn(geom, a=a, b=b, **sink_kwargs)
        return out
    
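The default behaviour mentioned above (0 in lse_mode, 1 in kernel mode, masking zero-weight entries) can be sketched as follows. This is a minimal NumPy sketch, not the actual ott implementation; the function name and masking convention are assumptions:

```python
import numpy as np

def default_dual_init(a: np.ndarray, lse_mode: bool = True) -> np.ndarray:
    """Hypothetical sketch of the default dual-potential initialization.

    Potentials start at 0 in log-space (lse_mode) or 1 in kernel space;
    entries whose marginal weight is zero are sent to -inf / 0 so they
    cannot contribute mass to the transport plan.
    """
    init = np.zeros_like(a) if lse_mode else np.ones_like(a)
    masked = -np.inf if lse_mode else 0.0
    return np.where(a > 0, init, masked)
```

A concrete initializer would then only need to override this default while inheriting the zero-weight handling.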

@JTT94 JTT94 marked this pull request as ready for review July 4, 2022 13:20
@JTT94 JTT94 marked this pull request as draft July 4, 2022 13:35
@JTT94 JTT94 marked this pull request as ready for review July 4, 2022 13:45
@michalk8 michalk8 added the enhancement New feature or request label Jul 5, 2022
@michalk8
Collaborator

@marcocuturi could you please enable the readthedocs build?

@marcocuturi
Contributor

@michalk8: I had some issues with the webhook that recompiles the docs on PRs. I will look into this soon.

@michalk8
Collaborator

@JTT94 could you please fix the BibTeX file and address the rest of the comments? After that, I think we can merge.

@JTT94
Collaborator Author

JTT94 commented Aug 17, 2022

Thanks for the feedback, this is good to be merged

@michalk8
Collaborator

Also just noticed that the initializers are not registered as PyTrees, i.e.

@jax.jit
def test():
    init = ott.core.initializers.SortingInitializer()
    return init

fails. Could you please register them using @jax.tree_util.register_pytree_node_class?
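A minimal sketch of such a registration, using a stand-in class rather than the actual ott code (the `tree_flatten`/`tree_unflatten` layout below is an assumption):

```python
import jax

@jax.tree_util.register_pytree_node_class
class SortingInitializer:
    """Stand-in for an ott initializer, showing pytree registration."""

    def __init__(self, max_iter: int = 100):
        self.max_iter = max_iter

    def tree_flatten(self):
        # No array leaves; configuration travels as static (hashable) aux data.
        return (), (self.max_iter,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        del children  # no leaves to restore
        return cls(max_iter=aux_data[0])

@jax.jit
def make_init():
    # Constructing and returning the instance inside jit now works,
    # because jit knows how to flatten/unflatten the registered class.
    return SortingInitializer()
```

Keeping the aux data hashable (a tuple rather than a dict) matters here, since jit hashes the treedef when caching compiled functions.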

@michalk8
Collaborator

Thanks a lot @JTT94 , could you please include the last change mentioned here: #98 (comment) ? It's more efficient than doing lax.cond. After that I will happily merge this.

Collaborator

@michalk8 michalk8 left a comment

Thanks @JTT94, LGTM, merging this!

@michalk8 michalk8 merged commit eac3315 into ott-jax:main Aug 17, 2022
@JTT94
Collaborator Author

JTT94 commented Aug 19, 2022

> Thanks @JTT94, LGTM, merging this!

Great! And thanks for your patience with the feedback, I appreciate it.

@JTT94 JTT94 mentioned this pull request Aug 19, 2022