Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move datasets from mong_gap and gromov_wasserstein notebooks to datasets.py file; add conditional distribution sampling to the gromov_wasserstein notebook. #467

Closed

Conversation

theouscidda6
Copy link
Contributor

This pull-request proposes to:

  • Move datasets from notebooks Monge_Gap.ipynb and gromov_wasserstein.ipynb into file datasets.py. Therefore, we create two new Dataset classes: SklearnDistribution and SortedSpiral.

  • Update the end of gromov_wasserstein.ipynb with a sampling of the conditional distributions of the discrete Gromov-Wasserstein coupling instead of recovering a one-to-one matching associating each point with the one it is most coupled.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link

codecov bot commented Nov 21, 2023

Codecov Report

Merging #467 (82b8fc3) into main (b50719b) will decrease coverage by 0.60%.
Report is 1 commits behind head on main.
The diff coverage is 40.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #467      +/-   ##
==========================================
- Coverage   90.45%   89.86%   -0.60%     
==========================================
  Files          58       58              
  Lines        6349     6423      +74     
  Branches      614      903     +289     
==========================================
+ Hits         5743     5772      +29     
- Misses        467      512      +45     
  Partials      139      139              
Files Coverage Δ
src/ott/datasets.py 58.40% <40.00%> (-36.47%) ⬇️

@marcocuturi
Copy link
Contributor

is this ready for review?

@theouscidda6
Copy link
Contributor Author

theouscidda6 commented Nov 28, 2023

is this ready for review?

I think so. The previous doc errors were due to the hackathon refactoring. It should be fine now.

@michalk8
Copy link
Collaborator

I think so. The previous doc errors were due to the hackathon refactoring. It should be fine now.

Should be fixed soon! once it's done, can you please rebase? Will ping you here.

@theouscidda6
Copy link
Contributor Author

I think so. The previous doc errors were due to the hackathon refactoring. It should be fine now.

Should be fixed soon! once it's done, can you please rebase? Will ping you here.

Of course, I'll be waiting for your feedback.

@@ -1,5 +1,15 @@
{
Copy link
Contributor

@marcocuturi marcocuturi Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove


Reply via ReviewNB

@@ -1,5 +1,15 @@
{
Copy link
Contributor

@marcocuturi marcocuturi Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update


Reply via ReviewNB

@@ -1,5 +1,15 @@
{
Copy link
Contributor

@marcocuturi marcocuturi Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #10.        model = models.models.MLP(

maybe import models to avoid models.models ?


Reply via ReviewNB

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it should be from ott.neural import models; model = models.MLP

Copy link

review-notebook-app bot commented Nov 28, 2023

View / edit / reply to this conversation on ReviewNB

marcocuturi commented on 2023-11-28T15:17:14Z
----------------------------------------------------------------

unfortunately this notation is not introduced elsewhere? either add equation to define it or remove


Copy link

review-notebook-app bot commented Nov 28, 2023

View / edit / reply to this conversation on ReviewNB

marcocuturi commented on 2023-11-28T15:17:15Z
----------------------------------------------------------------

maybe a comment on the fact that we expect somewhat the optimal alignment matrix to be close to identity.


Copy link

review-notebook-app bot commented Nov 28, 2023

View / edit / reply to this conversation on ReviewNB

marcocuturi commented on 2023-11-28T15:17:16Z
----------------------------------------------------------------

same, here $\pi^\star_\varepsilon$ needs to be defined somewhere.


Copy link

review-notebook-app bot commented Nov 28, 2023

View / edit / reply to this conversation on ReviewNB

marcocuturi commented on 2023-11-28T15:17:17Z
----------------------------------------------------------------

can we reduce the alpha of gray points?


@marcocuturi
Copy link
Contributor

I think all is fixed, so should be ready to be ready to merge once modifications are taken into account!

@michalk8 michalk8 added the enhancement New feature or request label Dec 4, 2023
@michalk8 michalk8 self-requested a review December 4, 2023 11:17
Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @theouscidda6 , I will review the changes in the notebook later!

@@ -87,6 +87,7 @@ docs = [
"sphinxcontrib-bibtex>=2.5.0",
"sphinxcontrib-spelling>=7.7.0",
"myst-nb>=0.17.1",
"scikit-learn>=1.0",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessary here, as scikit-learn is not needed when building the docs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be in the normal requirements instead.

"""Random sample generator from a Sklearn distribution.

Returns:
A generator of samples from the Sklearn distribution.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dedent please.

return self._create_sample_generators()

def __post_init__(self):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove empty space.

"""
return self._create_sample_generators()

def __post_init__(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing -> None:

init_rng: jax.Array
dim_data: int = 2
theta_rotation: float = 0.0
offset: Optional[jnp.ndarray] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this doesn't need to be None, can be just by default 0; would use offset: Union[float, jnp.ndarray] = 0.0

"""
return self._create_sample_generators()

def __post_init__(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a check in this function whether self.name is in correct values as

assert self.name in ("moon", etc.), self.na,me

Returns:
A generator of samples from the Sklearn distribution.
"""
return self._create_sample_generators()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also consider removing the private function and implementing everything directly in the __iter__ method.

@@ -105,6 +110,168 @@ def _create_sample_generators(self) -> Iterator[jnp.array]:
yield samples


@dataclasses.dataclass
class SklearnDistribution:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add this and all the newly created functions to the docs in docs/datasets.rst; I think you will need to create this file - please have a look at docs/utils.rst how it's done.

random_state=seed,
noise=self.std_noise,
)
samples = x[:, [2, 0]] if self.dim_data == 2 else x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for swapping the axes? If yes, please add a comment.

)
),
)
dim_data = 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this hardcoded here? I would consider removing this and just returning the 2 datasets.

valid_batch_size: int = 256,
rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
"""Sklearn samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This link is no longer correct.

@@ -1,5 +1,15 @@
{
Copy link
Collaborator

@michalk8 michalk8 Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #15.    from ott.neural.solvers import losses, map_estimator

This needs to be adjusted as

from ott.neural import losses

from ott.neural.solvers import map_estimator


Reply via ReviewNB

@@ -1,5 +1,15 @@
{
Copy link
Collaborator

@michalk8 michalk8 Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Links, such as ott.solvers.nn.losses.monge_gap,MapEstimator and W2NeuralDual need to be adjusted.


Reply via ReviewNB

@@ -1,5 +1,15 @@
{
Copy link
Collaborator

@michalk8 michalk8 Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this please


Reply via ReviewNB

Copy link

review-notebook-app bot commented Dec 4, 2023

View / edit / reply to this conversation on ReviewNB

michalk8 commented on 2023-12-04T14:52:21Z
----------------------------------------------------------------

mathcing -> matching


@michalk8
Copy link
Collaborator

@theouscidda6 what's the status of this PR?

@michalk8
Copy link
Collaborator

michalk8 commented Jun 5, 2024

@theouscidda6 because of the heave refactoring done in #466 , I will be closing this issue.
I will create later a new issue to unify the example datasets present in ott.datasets and the dataset class/dataloaders in ott.neural.datasets.

@michalk8 michalk8 closed this Jun 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants