Move datasets from Monge_Gap and gromov_wasserstein notebooks to datasets.py file; add conditional distribution sampling to the gromov_wasserstein notebook.
#467
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #467      +/-   ##
==========================================
- Coverage   90.45%   89.86%   -0.60%
==========================================
  Files          58       58
  Lines        6349     6423      +74
  Branches      614      903     +289
==========================================
+ Hits         5743     5772      +29
- Misses        467      512      +45
  Partials      139      139
Is this ready for review?
I think so. The previous doc errors were due to the hackathon refactoring. It should be fine now.
Should be fixed soon! Once it's done, can you please rebase? Will ping you here.
Of course, I'll be waiting for your feedback.
@@ -1,5 +1,15 @@
{
Line #10. model = models.models.MLP(
Maybe import models to avoid models.models?
Now it should be from ott.neural import models; model = models.MLP
marcocuturi commented on 2023-11-28T15:17:14Z: Unfortunately, this notation is not introduced elsewhere. Either add an equation to define it or remove it.
marcocuturi commented on 2023-11-28T15:17:15Z: Maybe add a comment on the fact that we expect the optimal alignment matrix to be somewhat close to identity.
marcocuturi commented on 2023-11-28T15:17:16Z: Same here.
marcocuturi commented on 2023-11-28T15:17:17Z: Can we reduce the alpha of the gray points?
I think all is fixed, so this should be ready to merge once the modifications are taken into account!
Thanks @theouscidda6, I will review the changes in the notebook later!
@@ -87,6 +87,7 @@ docs = [
"sphinxcontrib-bibtex>=2.5.0",
"sphinxcontrib-spelling>=7.7.0",
"myst-nb>=0.17.1",
"scikit-learn>=1.0",
This is not necessary here, as scikit-learn is not needed when building the docs.
This should be in the normal requirements instead.
"""Random sample generator from a Sklearn distribution.

Returns:
A generator of samples from the Sklearn distribution.
Dedent please.
return self._create_sample_generators()

def __post_init__(self):

Remove empty space.
"""
return self._create_sample_generators()

def __post_init__(self):
Missing -> None:
init_rng: jax.Array
dim_data: int = 2
theta_rotation: float = 0.0
offset: Optional[jnp.ndarray] = None
I guess this doesn't need to be None, it can just default to 0; I would use offset: Union[float, jnp.ndarray] = 0.0
"""
return self._create_sample_generators()

def __post_init__(self):
Add a check in this function that self.name is one of the valid values, e.g. assert self.name in ("moon", etc.), self.name
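The three suggestions above (the `-> None` annotation, a scalar default for `offset`, and a name check in `__post_init__`) could be combined roughly as follows. This is only a sketch: the class name `NamedDistribution`, the `shift` helper, and every valid name other than "moon" are assumptions made for illustration, and NumPy stands in for `jax.numpy` so the snippet runs without JAX.

```python
import dataclasses
from typing import Union

import numpy as np

# Assumed set of valid names; only "moon" appears in the review comment.
_VALID_NAMES = ("moon", "s_curve", "swiss")


@dataclasses.dataclass
class NamedDistribution:
  """Hypothetical stand-in for the dataset class under review."""
  name: str
  # A scalar default removes the need for an Optional/None branch.
  offset: Union[float, np.ndarray] = 0.0

  def __post_init__(self) -> None:
    # Fail early and report the offending value.
    assert self.name in _VALID_NAMES, self.name

  def shift(self, samples: np.ndarray) -> np.ndarray:
    # Broadcasting handles both a scalar and a per-dimension offset.
    return samples + self.offset


points = np.zeros((3, 2))
no_shift = NamedDistribution("moon").shift(points)
per_dim = NamedDistribution("moon", offset=np.array([1.0, -1.0])).shift(points)

try:
  NamedDistribution("circle")
  rejected = False
except AssertionError as err:
  rejected = True
  message = str(err)
```

Note that `assert` statements are skipped when Python runs with `-O`; raising a `ValueError` instead would keep the check active in that mode.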
Returns:
A generator of samples from the Sklearn distribution.
"""
return self._create_sample_generators()
Also consider removing the private function and implementing everything directly in the __iter__ method.
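A minimal sketch of what implementing the generator directly in `__iter__` could look like. The class `UniformBatches` and its fields are hypothetical, and the standard-library `random` module replaces the sklearn/JAX sampling used in the PR.

```python
import dataclasses
import itertools
import random
from typing import Iterator, List


@dataclasses.dataclass
class UniformBatches:
  """Hypothetical sampler; stands in for the sklearn-based one."""
  batch_size: int
  seed: int = 0

  def __iter__(self) -> Iterator[List[float]]:
    # A `yield` inside __iter__ makes it a generator function, so no
    # separate private helper is needed.
    rng = random.Random(self.seed)
    while True:
      yield [rng.random() for _ in range(self.batch_size)]


# The stream is infinite, so take a finite slice of it.
batches = list(itertools.islice(UniformBatches(batch_size=4), 3))
```

Because the random state is created inside `__iter__`, each new iteration over the same object restarts from the same seed.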
@@ -105,6 +110,168 @@ def _create_sample_generators(self) -> Iterator[jnp.array]:
yield samples


@dataclasses.dataclass
class SklearnDistribution:
Please add this and all the newly created functions to the docs in docs/datasets.rst; I think you will need to create this file - please have a look at docs/utils.rst to see how it's done.
random_state=seed,
noise=self.std_noise,
)
samples = x[:, [2, 0]] if self.dim_data == 2 else x
Any reason for swapping the axes? If yes, please add a comment.
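To make the question concrete, here is a small sketch of what the indexing `x[:, [2, 0]]` does, assuming `x` has three columns. The array values are made up for illustration.

```python
import numpy as np

# Three points in 3-D; columns are (x, y, z).
x = np.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
])

# Indexing with [2, 0] puts column 2 first and column 0 second, so the
# result is (z, x): a 2-D projection with the two kept axes swapped.
swapped = x[:, [2, 0]]

# Keeping the original axis order would be [0, 2] instead.
unswapped = x[:, [0, 2]]
```

Both variants drop the middle column; they differ only in the order of the two remaining axes, which is what the requested comment would need to justify.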
)
),
)
dim_data = 2
Why is this hardcoded here? I would consider removing this and just returning the 2 datasets.
valid_batch_size: int = 256,
rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
"""Sklearn samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`.
This link is no longer correct.
@@ -1,5 +1,15 @@
{
Line #15. from ott.neural.solvers import losses, map_estimator
This needs to be adjusted as
from ott.neural import losses
from ott.neural.solvers import map_estimator
@@ -1,5 +1,15 @@
{
Links, such as ott.solvers.nn.losses.monge_gap, MapEstimator and W2NeuralDual need to be adjusted.
@@ -1,5 +1,15 @@
{
michalk8 commented on 2023-12-04T14:52:21Z: mathcing -> matching
@theouscidda6 what's the status of this PR?
@theouscidda6 because of the heavy refactoring done in #466, I will be closing this issue.
This pull-request proposes to:
- Move datasets from notebooks Monge_Gap.ipynb and gromov_wasserstein.ipynb into file datasets.py. Therefore, we create two new Dataset classes: SklearnDistribution and SortedSpiral.
- Update the end of gromov_wasserstein.ipynb with a sampling of the conditional distributions of the discrete Gromov-Wasserstein coupling instead of recovering a one-to-one matching associating each point with the one it is most strongly coupled with.
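The difference between the two strategies in the second item can be sketched as follows. This is an illustrative NumPy version, not the notebook's code: the function name `sample_conditionals` and the toy coupling matrix are assumptions. Each row of the coupling is normalized into a conditional distribution over target points, and one target index is drawn per source point.

```python
import numpy as np


def sample_conditionals(coupling: np.ndarray, seed: int = 0) -> np.ndarray:
  """For each source point i, draw a target j with probability P[i, j] / sum_j P[i, j]."""
  rng = np.random.default_rng(seed)
  # Normalize each row of the coupling into a conditional distribution.
  conditionals = coupling / coupling.sum(axis=1, keepdims=True)
  n_targets = coupling.shape[1]
  return np.array([rng.choice(n_targets, p=row) for row in conditionals])


# Toy coupling (an assumption, not taken from the notebook).
coupling = np.array([
    [0.30, 0.03, 0.00],
    [0.03, 0.30, 0.01],
    [0.00, 0.01, 0.32],
])

# Stochastic assignment: one draw from each conditional distribution.
sampled = sample_conditionals(coupling)
# Deterministic alternative: pick the most strongly coupled target.
hard = coupling.argmax(axis=1)
```

Sampling keeps the uncertainty encoded in the coupling, whereas the argmax rule collapses each row to a single match.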