Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ Introduce initialization methods for Sinkhorn #98

Merged
merged 49 commits into from
Aug 17, 2022
Merged

Feature/ Introduce initialization methods for Sinkhorn #98

merged 49 commits into from
Aug 17, 2022

Conversation

JTT94
Copy link
Collaborator

@JTT94 JTT94 commented Jul 4, 2022

  • Examples here: https://colab.research.google.com/drive/1vncmDEr3t6_OKfVC0PJin8PIRBViPyO0?usp=sharing

  • Initialization methods stored in /core/initializers.py

    • Each initializer inherits from base class SinkhornInitializer
      • Class is more flexible in case of neural network initialisers in future
      • Currently added:
        • sorting initializer (n=m) for any cost
        • Gaussian initializer for squared ground cost added, only works for point cloud as need access to x
    • Each initializer has methods init_dual_a and init_dual_b
        def init_dual_a(
              self, ot_problem: LinearProblem, lse_mode: bool = True
          ) -> jnp.ndarray:
            """Initialzation for Sinkhorn potential f.
    
    • The base class also holds default behaviour, initialization for 0 and 1 depending on log lse_mode, also handling entries for 0 weights
  • Modification to Sinkhorn api,

    • pass instantiated initializer to Sinkhorn
    Sinkhorn(
       lse_mode=lse_mode,
       threshold=threshold,
       norm_error=norm_error,
       inner_iterations=inner_iterations,
       min_iterations=min_iterations,
       max_iterations=max_iterations,
       momentum=momentum_lib.Momentum(start=chg_momentum_from, value=momentum),
       anderson=anderson,
       implicit_diff=implicit_diff,
       parallel_dual_updates=parallel_dual_updates,
       use_danskin=use_danskin,
       potential_initializer=potential_initializer,
       jit=jit
    )
    
    • Can be passed to sinkhorn functional wrapper
    gaus_init = init_lib.GaussianInitializer()
    
    @jax.jit
    def run_sinkhorn_gaus_init(x, y, a=None, b=None):
        sink_kwargs = {'jit': True, 
                    'threshold': 0.001, 
                    'max_iterations': 10**5, 
                    'potential_initializer': gaus_init}
                    
        geom_kwargs = {'epsilon': 0.01}
        geom = PointCloud(x, y, **geom_kwargs)
        out = sinkhorn(geom, a=a, b=b, **sink_kwargs)
        return out
    

@JTT94 JTT94 marked this pull request as ready for review July 4, 2022 13:20
@JTT94 JTT94 marked this pull request as draft July 4, 2022 13:35
@JTT94 JTT94 marked this pull request as ready for review July 4, 2022 13:45
@michalk8 michalk8 added the enhancement New feature or request label Jul 5, 2022
ott/core/initializers.py Outdated Show resolved Hide resolved
@marcocuturi
Copy link
Contributor

@michalk8 : i had some issues with the webhook to enable recompiling docs on PR. I will look into this soon.

@michalk8
Copy link
Collaborator

@JTT94 could you please fix the bibtex file and deal with the rest of the comments? After that, think we can merge.

@JTT94
Copy link
Collaborator Author

JTT94 commented Aug 17, 2022

Thanks for the feedback, this is good to be merged

ott/core/initializers.py Outdated Show resolved Hide resolved
ott/core/initializers.py Outdated Show resolved Hide resolved
ott/core/initializers.py Outdated Show resolved Hide resolved
@michalk8
Copy link
Collaborator

Also just noticed that the initializers are not registered as PyTrees, i.e.

@jax.jit
def test():
    init = ott.core.initializers.SortingInitializer()
    return init

fails. Could you please register them using @jax.tree_util.register_pytree_node_class?

ott/core/initializers.py Outdated Show resolved Hide resolved
@michalk8
Copy link
Collaborator

Thanks a lot @JTT94 , could you please include the last change mentioned here: #98 (comment) ? It's more efficient than doing lax.cond. After that I will happily merge this.

Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @JTT94, LGTM, merging this!

@michalk8 michalk8 merged commit eac3315 into ott-jax:main Aug 17, 2022
@JTT94
Copy link
Collaborator Author

JTT94 commented Aug 19, 2022

Thanks @JTT94, LGTM, merging this!

Great! And thanks for your patience with feedback, I appreciate it

@JTT94 JTT94 mentioned this pull request Aug 19, 2022
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
* add sorting, gaus initializers, add gaus helpers to tools

* add initialization logic to sinkhorn

* remove general ot problem type

* remove import tools.gaussian from top level

* remove problems from top level

* do not register initializer as pytree

* add initializer to make

* rename init arg to ot_problem

* rename init arg to ot_problem

* scale gaus init by 2

* typo

* add basic speed tests

* add init to transport tools wrapper, tidy docstring

* ceneter potentials in initializers

* fix lse for null weights

* fix flake8 and accidental removal

* tidy docstrings

* tidy docstrings

* docstring flake8

* flake 8 formatting

* fix typo

* fix stop gradient in Gaussian to include weights and x,y

* fix stop gradient in Gaussian to include weights and x,y

* fix docstring spaces

* feedback from initial review

* re order local functions before state init

* optional init_f in sorting init

* docstring insert line before return

* lint fix

* incorporate feedback in commit

* tidy tests, use jax.lax.cond for logic instead of if

* add docs, rename sorting initializer

* fix merge conflict

* resolve test errors in sinkhorn test

* incorporate feedback, update tests to pytest, change docstrings, introduce defaultinit class

* fix docstring spaces

* remove spaces and add bibtex

* add errors for non square cost matrix for sorting, online geoms for initializers, tests

* merge fix lint

* merge fix lint

* add initializers as pytees

* add init scaling tests

* add init scaling tests

* simplify vector update flag in sorting initializer

* Fix documentation rendering

* [ci skip] Fix typo in docs, use fixture in tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants