Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add PyTorch backend for soft-DTW #431

Merged

Conversation

YannCabanes
Copy link
Contributor

@YannCabanes YannCabanes commented Nov 3, 2022

This PR plans to make compatible the files soft_dtw_fast.py and softdtw_variants.py with the PyTorch backend.

We take inspiration from the following GitHub repository:
https://github.com/Sleepwalking/pytorch-softdtw/blob/master/soft_dtw.py

Github repository of Mehran Maghoumi on Soft-DTW using Pytorch with cuda:
https://github.com/Maghoumi/pytorch-softdtw-cuda/blob/master/soft_dtw_cuda.py

An introduction to Dynamic Time Warping can be found at:
https://rtavenar.github.io/blog/dtw.html

An introduction about the differentiability of DTW and the case of soft-DTW can be found at:
https://rtavenar.github.io/blog/softdtw.html

We also take inspiration from the python package geomstats [JMLR:v21:19-027] (https://github.com/geomstats/geomstats/) about ML in Riemannian manifolds as a source of inspiration to implement the multiple backends functions.

References

[JMLR:v21:19-027] Nina Miolane, Nicolas Guigui, Alice Le Brigant, Johan Mathe, Benjamin Hou, Yann Thanwerdas, Stefan Heyder, Olivier Peltre, Niklas Koep, Hadi Zaatiti, Hatem Hajri, Yann Cabanes, Thomas Gerald, Paul Chauchat, Christian Shewmake, Daniel Brooks, Bernhard Kainz, Claire Donnat, Susan Holmes and Xavier Pennec. Geomstats: A Python Package for Riemannian Geometry in Machine Learning, Journal of Machine Learning Research, 2020, volume 21, number 223, pages 1-9, http://jmlr.org/papers/v21/19-027.html

@YannCabanes
Copy link
Contributor Author

All checks of tslearn main branch have passed.

@YannCabanes YannCabanes changed the title [WIP] Add PyTorch backend for soft-dtw [WIP] Add PyTorch backend for soft-DTW Nov 3, 2022
@codecov-commenter
Copy link

codecov-commenter commented Nov 3, 2022

Codecov Report

Patch coverage: 90.41% and project coverage change: -1.52 ⚠️

Comparison is base (a091483) 94.53% compared to head (fef2486) 93.02%.

❗ Current head fef2486 differs from pull request most recent head 5f90b87. Consider uploading reports for the commit 5f90b87 to get more accurate results

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #431      +/-   ##
==========================================
- Coverage   94.53%   93.02%   -1.52%     
==========================================
  Files          62       67       +5     
  Lines        4743     5663     +920     
==========================================
+ Hits         4484     5268     +784     
- Misses        259      395     +136     
Impacted Files Coverage Δ
tslearn/metrics/sax.py 100.00% <ø> (ø)
tslearn/metrics/soft_dtw_fast.py 57.39% <30.98%> (-42.61%) ⬇️
tslearn/backend/pytorch_backend.py 77.20% <77.20%> (ø)
tslearn/metrics/softdtw_variants.py 93.33% <90.90%> (-4.51%) ⬇️
tslearn/clustering/kmeans.py 91.69% <91.66%> (ø)
tslearn/metrics/soft_dtw_loss_pytorch.py 92.64% <92.64%> (ø)
tslearn/backend/backend.py 93.54% <93.54%> (ø)
tslearn/backend/numpy_backend.py 93.93% <93.93%> (ø)
tslearn/utils/utils.py 93.85% <95.23%> (+0.17%) ⬆️
tslearn/metrics/dtw_variants.py 95.48% <96.76%> (-1.63%) ⬇️
... and 5 more

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@kushalkolar
Copy link
Contributor

Do you think that in general tslearn would benefit from having torch versions for metrics or models? DTW based metrics would certainly be a useful loss function to utilize in pytorch for time-series models. The modules you've written so far seem useful for implementing metrics beyond just sDTW.

@YannCabanes
Copy link
Contributor Author

Hello @kushalkolar,
Thank you for your interest.
You are right, having the PyTorch backend for the soft-DTW is only a beginning.
In the end, @rtavenar and I plan to have all of tslearn tools supported by NumPy, PyTorch, TensorFlow and then maybe Jax and additional backends.
We planned to start with only the PyTorch backend for the soft-DTW metrics, however it is already complicated because of the Numba @njit decorator which only supports the NumPy backend.

@YannCabanes
Copy link
Contributor Author

I am looking for ideas to convert the PyTorch input arrays to NumPy before using Numba:
https://stackoverflow.com/questions/63169760/how-can-i-use-numba-for-pytorch-tensors

@YannCabanes
Copy link
Contributor Author

YannCabanes commented Nov 28, 2022

My problem is the following: Numba only supports NumPy arrays.
I have created a decorator which converts inputs of a function to Numpy and then convert back the output to the original backend so I can write:

@convert_backend_to_numpy()
@njit(fastmath=True)
def function(args):
    ...

Then I can call this function using either a NumPy or a PyTorch input.
The output has the same type than the input.

@kushalkolar
Copy link
Contributor

Some quick thoughts: wouldn't it be better to make a pure torch implementation, at least for the most compute intense steps, than have lots back and forth with numpy? sDTW could be a great loss function for some use cases and I wonder if constant back and forth between numpy and torch Tensors (or jax etc. Tensors) could slow things down.

@YannCabanes
Copy link
Contributor Author

YannCabanes commented Nov 28, 2022

However, the @convert_backend_to_numpy() decorator does not work when a njit decorated function is calling another auxiliary function (for example _soft_dtw calling _softmin3) .
If a main function is decorated with jit, then its auxiliary functions also have to be decorated with jit and the convert_backend_to_numpy decorator can not precede the jit decorator of the auxiliary functions.

I could remove all convert_backend_to_numpy decorators from the auxiliary functions but then they will only support the NumPy backend.

@rtavenar, do you know about how we could do for the auxiliary jit decorated function to support the PyTorch backend?
Or should the auxiliary functions only support the NumPy backend?

@YannCabanes
Copy link
Contributor Author

Hello @kushalkolar,
Yes, I am also comparing the execution times between:

  • converting inputs from PyTorch to NumPy --> Use Numba --> convert back the outputs to PyTorch
  • use directly Pytorch
    It seems that for very short functions such as _softmin3, it is better to use directly PyTorch, but for longer functions it might be faster to convert to NumPy and use Numba.

@YannCabanes
Copy link
Contributor Author

Even if I choose to use directly the PyTorch backend, I still have a problem related to Numba.
I have created a decorator called @njit_if_numpy_backend which decorates a function only if the inputs use the NumPy backend.
However, I have a problem when a main jit decorated function is calling an auxiliary function.
If the main function is using the NumPy backend, then the auxiliary function also has to be decorated with jit.
But if the PyTorch backend is used in the main function, then the auxiliary function can not be jit decorated.

@YannCabanes
Copy link
Contributor Author

Numba does not support a backend as input argument of a njit decorated function as a backend does not correspond to a Numba type.
We obtain the following error:
Cannot determine Numba type of <class 'tslearn.backend.backend.Backend'>

Likewise, we can not write in a njit decorated function:

if backend is None:
    backend = Backend(data)

as Backend does not have a Numba type.
Even if backend is not None in the execution, this will raise an Error during the compilation.

@rtavenar
Copy link
Member

rtavenar commented Dec 5, 2022

Maybe you can just (at least for a start) have constant integer values that identify the backends. Would that help, or is the problem due to calling the Backend builder?

@YannCabanes
Copy link
Contributor Author

YannCabanes commented Dec 5, 2022

Hello @rtavenar, the problem is due to calling the Backend builder as it does not have a Numba type.
Elements which can be used in a Numba njit decorated function are mainly:

  • simple arithmetic Python operations (+, -, *, / ...),
  • loops,
  • NumPy functions,
  • jit or njit decorated auxiliary functions.

@YannCabanes
Copy link
Contributor Author

The tests have passed under MacOS Python 3.9 (skipping the failing tests):

tslearn/tests/test_metrics.py::test_gak PASSED                           [ 72%]
tslearn/tests/test_metrics.py::test_gamma_soft_dtw SKIPPED (Test failing
for MacOS with python3.9 (Segmentation fault))                           [ 73%]
tslearn/tests/test_metrics.py::test_symmetric_cdist SKIPPED (Test
failing for MacOS with python3.9 (Segmentation fault))                   [ 74%]

@YannCabanes
Copy link
Contributor Author

The tests have passed under MacOS Python 3.9 (skipping the failing tests).

1 similar comment
@YannCabanes
Copy link
Contributor Author

The tests have passed under MacOS Python 3.9 (skipping the failing tests).

@YannCabanes
Copy link
Contributor Author

I will now restore the last good commit bf7bb02 before investigating on the failing test under MacOS for Python 3.9 (Segmentation fault).

@YannCabanes YannCabanes force-pushed the add-pytorch-backend-for-softdtw branch from 29bed46 to bf7bb02 Compare July 2, 2023 06:08
@YannCabanes
Copy link
Contributor Author

The readthedocs test in failing:

/home/docs/checkouts/readthedocs.org/user_builds/tslearn/checkouts/431/docs/examples/classification/plot_shapelet_distances.py failed leaving traceback:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/checkouts/431/docs/examples/classification/plot_shapelet_distances.py", line 53, in <module>
    shp_clf.fit(X_train, y_train)
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/shapelets/shapelets.py", line 444, in fit
    self._set_model_layers(X=X, ts_sz=sz, d=d, n_classes=n_labels)
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/shapelets/shapelets.py", line 710, in _set_model_layers
    shapelet_layers = [
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/shapelets/shapelets.py", line 711, in <listcomp>
    LocalSquaredDistanceLayer(
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/shapelets/shapelets.py", line 135, in build
    self.kernel = self.add_weight(name='kernel',
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/shapelets/shapelets.py", line 106, in __call__
    shapelets = _kmeans_init_shapelets(self.X_,
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/shapelets/shapelets.py", line 89, in _kmeans_init_shapelets
    return TimeSeriesKMeans(n_clusters=n_shapelets,
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/clustering/kmeans.py", line 780, in fit
    self._fit_one_init(X_, x_squared_norms, rs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/tslearn/envs/431/lib/python3.8/site-packages/tslearn-0.5.3.2-py3.8.egg/tslearn/clustering/kmeans.py", line 629, in _fit_one_init
    self.cluster_centers_ = _kmeans_plusplus(
TypeError: _kmeans_plusplus() missing 1 required positional argument: 'sample_weight'

as well as the test on MacOS for Python 3.9:

tslearn/tests/test_metrics.py::test_gamma_soft_dtw PASSED                [ 61%]
tslearn/tests/test_metrics.py::test_symmetric_cdist Fatal Python error: Segmentation fault

Thread 0x0000000118698600 (most recent call first):
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/torch/_tensor_str.py", line 153 in __init__
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/torch/_tensor_str.py", line 327 in _tensor_str
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/torch/_tensor_str.py", line 567 in _str_intern
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/torch/_tensor_str.py", line 636 in _str
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/torch/_tensor.py", line 426 in __repr__
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/torch/_tensor.py", line 873 in __format__
  File "/Users/runner/work/1/s/tslearn/backend/backend.py", line 34 in instanciate_backend
  File "/Users/runner/work/1/s/tslearn/metrics/dtw_variants.py", line 646 in dtw
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/joblib/parallel.py", line 1784 in _get_sequential_output
  File "/Users/runner/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/joblib/parallel.py", line 1855 in __call__
  File "/Users/runner/work/1/s/tslearn/metrics/utils.py", line 92 in _cdist_generic
  File "/Users/runner/work/1/s/tslearn/metrics/dtw_variants.py", line 1624 in cdist_dtw
  File "/Users/runner/work/1/s/tslearn/tests/test_metrics.py", line 476 in test_symmetric_cdist

@YannCabanes
Copy link
Contributor Author

Only the docs/readthedocs.org:tslearn, tslearn-team.tslearn and tslearn-team.tslearn (macOS Python39) tests are failing.

@YannCabanes
Copy link
Contributor Author

The test for MacOS Python 3.9 has failed:

tslearn/tests/test_metrics.py::test_gamma_soft_dtw Fatal Python error: Segmentation fault

Thread 0x000000010c149600 (most recent call first):
  File "/Users/runner/work/1/s/tslearn/backend/pytorch_backend.py", line 172 in pdist
  File "/Users/runner/work/1/s/tslearn/metrics/softdtw_variants.py", line 313 in sigma_gak
  File "/Users/runner/work/1/s/tslearn/metrics/softdtw_variants.py", line 361 in gamma_soft_dtw
  File "/Users/runner/work/1/s/tslearn/tests/test_metrics.py", line 465 in test_gamma_soft_dtw

@YannCabanes
Copy link
Contributor Author

All checks have passed!

@YannCabanes
Copy link
Contributor Author

YannCabanes commented Jul 2, 2023

Hello @rtavenar and @johannfaouzi,

I did not find the reason of the Segmentation fault error on MacOS for Python 3.9, I have skipped the 2 failing tests on MacOS for Python 3.9:

  • test_symmetric_cdist --> The failing rate of this test if about 3/4
  • test_gamma_soft_dtw --> The failing rate of this test if about 1/5
    When these tests are failing, the error messages can be different from one fail to the other: the Segmentation fault can occur in different lines of the code.

Other tests might also fail with a very low failing rate.
I consider that this PR is ready to be merged.

Could you review my PR please?

@YannCabanes
Copy link
Contributor Author

All checks have passed.

@rtavenar
Copy link
Member

rtavenar commented Jul 3, 2023

This is ready for merge, congrats @YannCabanes on this huge effort!

@rtavenar rtavenar merged commit 7be1c45 into tslearn-team:main Jul 3, 2023
12 checks passed
@YannCabanes YannCabanes deleted the add-pytorch-backend-for-softdtw branch July 13, 2023 07:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants